From 4b23670dd697307c1c1f02ca6a478db480e3aa79 Mon Sep 17 00:00:00 2001 From: "deepsource-autofix[bot]" <62050782+deepsource-autofix[bot]@users.noreply.github.com> Date: Sun, 21 Aug 2022 14:01:02 +0000 Subject: [PATCH] Format code with black, yapf, autopep8 and isort This commit fixes the style issues introduced in cc2f5f5 according to the output from black, yapf, autopep8 and isort. Details: https://deepsource.io/gh/shubham11941140/pandas/transform/bdcb7ef3-e591-4442-bb51-302ebe20235e/ --- pandas/core/computation/pytables.py | 68 +- pandas/core/frame.py | 1052 +++++++++-------- pandas/core/indexes/base.py | 924 ++++++++------- pandas/core/indexes/datetimelike.py | 47 +- pandas/core/indexes/range.py | 129 +- pandas/core/nanops.py | 172 +-- pandas/core/ops/__init__.py | 111 +- pandas/io/formats/excel.py | 127 +- pandas/io/formats/style.py | 463 ++++---- pandas/io/formats/style_render.py | 575 +++++---- pandas/io/json/_json.py | 171 ++- pandas/io/parsers/python_parser.py | 322 +++-- pandas/io/parsers/readers.py | 225 ++-- pandas/io/pytables.py | 696 +++++------ pandas/io/sas/sas7bdat.py | 266 ++--- pandas/io/sql.py | 194 +-- pandas/io/stata.py | 522 ++++---- pandas/tests/apply/test_frame_apply.py | 542 +++++---- pandas/tests/apply/test_frame_transform.py | 51 +- pandas/tests/apply/test_str.py | 24 +- .../tests/frame/methods/test_interpolate.py | 255 ++-- pandas/tests/indexes/test_setops.py | 78 +- .../tests/indexing/multiindex/test_slice.py | 494 ++++---- pandas/tests/io/json/test_pandas.py | 489 ++++---- pandas/tests/io/pytables/common.py | 14 +- pandas/tests/util/test_assert_series_equal.py | 33 +- pandas/tseries/frequencies.py | 40 +- pandas/tseries/holiday.py | 83 +- setup.py | 225 ++-- 29 files changed, 4347 insertions(+), 4045 deletions(-) diff --git a/pandas/core/computation/pytables.py b/pandas/core/computation/pytables.py index 53196a240628f..6db4a722d5b98 100644 --- a/pandas/core/computation/pytables.py +++ b/pandas/core/computation/pytables.py @@ -38,7 +38,7 @@ class PyTablesScope(_scope.Scope): - __slots__ = ("queryables",) + __slots__ = ("queryables", ) queryables: dict[str, Any] @@ -49,7 +49,9 @@ def __init__( local_dict=None, queryables: dict[str, Any] | None = None, ): - super().__init__(level + 1, global_dict=global_dict, local_dict=local_dict) + super().__init__(level + 1, + global_dict=global_dict, + local_dict=local_dict) self.queryables = queryables or {} @@ -87,6 +89,7 @@ def value(self): class Constant(Term): + def __init__(self, value, env: PyTablesScope, side=None, encoding=None): assert isinstance(env, PyTablesScope), type(env) super().__init__(value, env, side=side, encoding=encoding) @@ -103,7 +106,8 @@ class BinOp(ops.BinOp): queryables: dict[str, Any] condition: str | None - def __init__(self, op: str, lhs, rhs, queryables: dict[str, Any], encoding): + def __init__(self, op: str, lhs, rhs, queryables: dict[str, Any], + encoding): super().__init__(op, lhs, rhs) self.queryables = queryables self.encoding = encoding @@ -113,6 +117,7 @@ def _disallow_scalar_only_bool_ops(self): pass def prune(self, klass): + def pr(left, right): """create and return a new specialized BinOp from myself""" if left is None: @@ -137,9 +142,11 @@ def pr(left, right): elif isinstance(right, k): return right - return k( - self.op, left, right, queryables=self.queryables, encoding=self.encoding - ).evaluate() + return k(self.op, + left, + right, + queryables=self.queryables, + encoding=self.encoding).evaluate() left, right = self.lhs, self.rhs @@ -256,7 +263,8 @@ def stringify(value): # 
string quoting return TermValue(v, stringify(v), "string") else: - raise TypeError(f"Cannot compare {v} of type {type(v)} to {kind} column") + raise TypeError( + f"Cannot compare {v} of type {type(v)} to {kind} column") def convert_values(self): pass @@ -268,7 +276,8 @@ class FilterBinOp(BinOp): def __repr__(self) -> str: if self.filter is None: return "Filter: Not Initialized" - return pprint_thing(f"[Filter : [{self.filter[0]}] -> [{self.filter[1]}]") + return pprint_thing( + f"[Filter : [{self.filter[0]}] -> [{self.filter[1]}]") def invert(self): """invert the filter""" @@ -324,6 +333,7 @@ def generate_filter_op(self, invert: bool = False): class JointFilterBinOp(FilterBinOp): + def format(self): raise NotImplementedError("unable to collapse Joint Filters") @@ -332,6 +342,7 @@ def evaluate(self): class ConditionBinOp(BinOp): + def __repr__(self) -> str: return pprint_thing(f"[Condition : [{self.condition}]]") @@ -341,8 +352,7 @@ def invert(self): # self.condition = "~(%s)" % self.condition # return self raise NotImplementedError( - "cannot use an invert condition when passing to numexpr" - ) + "cannot use an invert condition when passing to numexpr") def format(self): """return the actual ne format""" @@ -378,12 +388,14 @@ def evaluate(self): class JointConditionBinOp(ConditionBinOp): + def evaluate(self): self.condition = f"({self.lhs.condition} {self.op} {self.rhs.condition})" return self class UnaryOp(ops.UnaryOp): + def prune(self, klass): if self.op != "~": @@ -392,13 +404,11 @@ def prune(self, klass): operand = self.operand operand = operand.prune(klass) - if operand is not None and ( - issubclass(klass, ConditionBinOp) - and operand.condition is not None - or not issubclass(klass, ConditionBinOp) - and issubclass(klass, FilterBinOp) - and operand.filter is not None - ): + if operand is not None and (issubclass(klass, ConditionBinOp) + and operand.condition is not None + or not issubclass(klass, ConditionBinOp) + and issubclass(klass, FilterBinOp) + and operand.filter is not None): return operand.invert() return None @@ -429,9 +439,9 @@ def visit_Index(self, node, **kwargs): return self.visit(node.value).value def visit_Assign(self, node, **kwargs): - cmpr = ast.Compare( - ops=[ast.Eq()], left=node.targets[0], comparators=[node.value] - ) + cmpr = ast.Compare(ops=[ast.Eq()], + left=node.targets[0], + comparators=[node.value]) return self.visit(cmpr) def visit_Subscript(self, node, **kwargs): @@ -452,8 +462,7 @@ def visit_Subscript(self, node, **kwargs): return self.const_type(value[slobj], self.env) except TypeError as err: raise ValueError( - f"cannot subscript {repr(value)} with {repr(slobj)}" - ) from err + f"cannot subscript {repr(value)} with {repr(slobj)}") from err def visit_Attribute(self, node, **kwargs): attr = node.attr @@ -506,10 +515,8 @@ def _validate_where(w): TypeError : An invalid data type was passed in for w (e.g. dict). 
""" if not (isinstance(w, (PyTablesExpr, str)) or is_list_like(w)): - raise TypeError( - "where must be passed as a string, PyTablesExpr, " - "or list-like of PyTablesExpr" - ) + raise TypeError("where must be passed as a string, PyTablesExpr, " + "or list-like of PyTablesExpr") return w @@ -608,15 +615,13 @@ def evaluate(self): except AttributeError as err: raise ValueError( f"cannot process expression [{self.expr}], [{self}] " - "is not a valid condition" - ) from err + "is not a valid condition") from err try: self.filter = self.terms.prune(FilterBinOp) except AttributeError as err: raise ValueError( f"cannot process expression [{self.expr}], [{self}] " - "is not a valid filter" - ) from err + "is not a valid filter") from err return self.condition, self.filter @@ -647,7 +652,8 @@ def maybe_expression(s) -> bool: """loose checking if s is a pytables-acceptable expression""" if not isinstance(s, str): return False - ops = PyTablesExprVisitor.binary_ops + PyTablesExprVisitor.unary_ops + ("=",) + ops = PyTablesExprVisitor.binary_ops + PyTablesExprVisitor.unary_ops + ( + "=", ) # make sure we have an op at least return any(op in s for op in ops) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index deb117d70f9df..746c4903bd9a2 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -224,16 +224,22 @@ # Docstring templates _shared_doc_kwargs = { - "axes": "index, columns", - "klass": "DataFrame", - "axes_single_arg": "{0 or 'index', 1 or 'columns'}", - "axis": """axis : {0 or 'index', 1 or 'columns'}, default 0 + "axes": + "index, columns", + "klass": + "DataFrame", + "axes_single_arg": + "{0 or 'index', 1 or 'columns'}", + "axis": + """axis : {0 or 'index', 1 or 'columns'}, default 0 If 0 or 'index': apply function to each column. If 1 or 'columns': apply function to each row.""", - "inplace": """ + "inplace": + """ inplace : bool, default False If True, performs operation inplace and returns None.""", - "optional_by": """ + "optional_by": + """ by : str or list of str Name or list of names to sort by. @@ -241,12 +247,15 @@ levels and/or column labels. - if `axis` is 1 or `'columns'` then `by` may contain column levels and/or index labels.""", - "optional_labels": """labels : array-like, optional + "optional_labels": + """labels : array-like, optional New labels / index to conform the axis specified by 'axis' to.""", - "optional_axis": """axis : int or str, optional + "optional_axis": + """axis : int or str, optional Axis to target. 
Can be either the axis name ('index', 'columns') or number (0, 1).""", - "replace_iloc": """ + "replace_iloc": + """ This differs from updating with ``.loc`` or ``.iloc``, which require you to specify a location to update with some value.""", } @@ -453,7 +462,6 @@ 3 bar 8 """ - # ----------------------------------------------------------------------- # DataFrame class @@ -619,11 +627,9 @@ def __init__( if isinstance(data, dict): # retain pre-GH#38939 default behavior copy = True - elif ( - manager == "array" - and isinstance(data, (np.ndarray, ExtensionArray)) - and data.ndim == 2 - ): + elif (manager == "array" + and isinstance(data, (np.ndarray, ExtensionArray)) + and data.ndim == 2): # INFO(ArrayManager) by default copy the 2D input array to get # contiguous 1D arrays copy = True @@ -631,13 +637,22 @@ def __init__( copy = False if isinstance(data, (BlockManager, ArrayManager)): - mgr = self._init_mgr( - data, axes={"index": index, "columns": columns}, dtype=dtype, copy=copy - ) + mgr = self._init_mgr(data, + axes={ + "index": index, + "columns": columns + }, + dtype=dtype, + copy=copy) elif isinstance(data, dict): # GH#38939 de facto copy defaults to False only in non-dict cases - mgr = dict_to_mgr(data, index, columns, dtype=dtype, copy=copy, typ=manager) + mgr = dict_to_mgr(data, + index, + columns, + dtype=dtype, + copy=copy, + typ=manager) elif isinstance(data, ma.MaskedArray): import numpy.ma.mrecords as mrecords @@ -721,7 +736,8 @@ def __init__( # error: Argument 1 to "ensure_index" has incompatible type # "Collection[Any]"; expected "Union[Union[Union[ExtensionArray, # ndarray], Index, Series], Sequence[Any]]" - columns = ensure_index(columns) # type: ignore[arg-type] + columns = ensure_index( + columns) # type: ignore[arg-type] arrays, columns, index = nested_data_to_arrays( # error: Argument 3 to "nested_data_to_arrays" has incompatible # type "Optional[Collection[Any]]"; expected "Optional[Index]" @@ -779,7 +795,11 @@ def __init__( construct_1d_arraylike_from_scalar(data, len(index), dtype) for _ in range(len(columns)) ] - mgr = arrays_to_mgr(values, columns, index, dtype=None, typ=manager) + mgr = arrays_to_mgr(values, + columns, + index, + dtype=None, + typ=manager) else: arr2d = construct_2d_arraylike_from_scalar( data, @@ -901,8 +921,7 @@ def _can_fast_transpose(self) -> bool: # "_values" incompatible with return type "ndarray" in supertype "NDFrame" @property def _values( # type: ignore[override] - self, - ) -> np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray: + self, ) -> np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray: """ Analogue to ._values that may return a 2D ExtensionArray. 
""" @@ -911,7 +930,8 @@ def _values( # type: ignore[override] mgr = self._mgr if isinstance(mgr, ArrayManager): - if len(mgr.arrays) == 1 and not is_1d_only_ea_dtype(mgr.arrays[0].dtype): + if len(mgr.arrays) == 1 and not is_1d_only_ea_dtype( + mgr.arrays[0].dtype): # error: Item "ExtensionArray" of "Union[ndarray, ExtensionArray]" # has no attribute "reshape" return mgr.arrays[0].reshape(-1, 1) # type: ignore[union-attr] @@ -927,7 +947,8 @@ def _values( # type: ignore[override] return self.values # more generally, whatever we allow in NDArrayBackedExtensionBlock - arr = cast("np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray", arr) + arr = cast("np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray", + arr) return arr.T # ---------------------------------------------------------------------- @@ -957,8 +978,7 @@ def _repr_fits_horizontal_(self, ignore_width: bool = False) -> bool: # exceed max columns if (max_columns and nb_columns > max_columns) or ( - (not ignore_width) and width and nb_columns > (width // 2) - ): + (not ignore_width) and width and nb_columns > (width // 2)): return False # used by repr_html under IPython notebook or scripts ignore terminal @@ -966,7 +986,8 @@ def _repr_fits_horizontal_(self, ignore_width: bool = False) -> bool: if ignore_width or not console.in_interactive_session(): return True - if get_option("display.width") is not None or console.in_ipython_frontend(): + if get_option( + "display.width") is not None or console.in_ipython_frontend(): # check at least the column row for excessive width max_rows = 1 else: @@ -983,7 +1004,7 @@ def _repr_fits_horizontal_(self, ignore_width: bool = False) -> bool: if max_rows is not None: # unlimited rows # min of two, where one may be None - d = d.iloc[: min(max_rows, len(d))] + d = d.iloc[:min(max_rows, len(d))] else: return True @@ -998,9 +1019,8 @@ def _info_repr(self) -> bool: True if the repr should show the info view. """ info_repr_option = get_option("display.large_repr") == "info" - return info_repr_option and not ( - self._repr_fits_horizontal_() and self._repr_fits_vertical_() - ) + return info_repr_option and not (self._repr_fits_horizontal_() + and self._repr_fits_vertical_()) def __repr__(self) -> str: """ @@ -1118,7 +1138,8 @@ def to_string( "every integers corresponds with one column. If a dict is given, the key " "references the column, while the value defines the space to use.", ) - @Substitution(shared_params=fmt.common_docstring, returns=fmt.return_docstring) + @Substitution(shared_params=fmt.common_docstring, + returns=fmt.return_docstring) def to_string( self, buf: FilePath | WriteBuffer[str] | None = None, @@ -1217,9 +1238,7 @@ def style(self) -> Styler: return Styler(self) - _shared_docs[ - "items" - ] = r""" + _shared_docs["items"] = r""" Iterate over (column name, Series) pairs. Iterates over the DataFrame columns, returning a tuple with @@ -1334,9 +1353,9 @@ def iterrows(self) -> Iterable[tuple[Hashable, Series]]: s = klass(v, index=columns, name=k) yield k, s - def itertuples( - self, index: bool = True, name: str | None = "Pandas" - ) -> Iterable[tuple[Any, ...]]: + def itertuples(self, + index: bool = True, + name: str | None = "Pandas") -> Iterable[tuple[Any, ...]]: """ Iterate over DataFrame rows as namedtuples. 
@@ -1411,8 +1430,7 @@ def itertuples( # https://github.com/python/mypy/issues/9046 # error: namedtuple() expects a string literal as the first argument itertuple = collections.namedtuple( # type: ignore[misc] - name, fields, rename=True - ) + name, fields, rename=True) return map(itertuple._make, zip(*arrays)) # fallback to regular tuples @@ -1511,7 +1529,8 @@ def dot(self, other: AnyArrayLike | DataFrame) -> DataFrame | Series: """ if isinstance(other, (Series, DataFrame)): common = self.columns.union(other.index) - if len(common) > len(self.columns) or len(common) > len(other.index): + if len(common) > len(self.columns) or len(common) > len( + other.index): raise ValueError("matrices are not aligned") left = self.reindex(columns=common, copy=False) @@ -1528,11 +1547,12 @@ def dot(self, other: AnyArrayLike | DataFrame) -> DataFrame | Series: ) if isinstance(other, DataFrame): - return self._constructor( - np.dot(lvals, rvals), index=left.index, columns=other.columns - ) + return self._constructor(np.dot(lvals, rvals), + index=left.index, + columns=other.columns) elif isinstance(other, Series): - return self._constructor_sliced(np.dot(lvals, rvals), index=left.index) + return self._constructor_sliced(np.dot(lvals, rvals), + index=left.index) elif isinstance(rvals, (np.ndarray, Index)): result = np.dot(lvals, rvals) if result.ndim == 2: @@ -1548,13 +1568,13 @@ def __matmul__(self, other: Series) -> Series: @overload def __matmul__( - self, other: AnyArrayLike | DataFrame | Series - ) -> DataFrame | Series: + self, + other: AnyArrayLike | DataFrame | Series) -> DataFrame | Series: ... def __matmul__( - self, other: AnyArrayLike | DataFrame | Series - ) -> DataFrame | Series: + self, + other: AnyArrayLike | DataFrame | Series) -> DataFrame | Series: """ Matrix multiplication using binary `@` operator in Python>=3.5. """ @@ -1677,7 +1697,8 @@ def from_dict( data, index = list(data.values()), list(data.keys()) elif orient in ("columns", "tight"): if columns is not None: - raise ValueError(f"cannot use columns parameter with orient='{orient}'") + raise ValueError( + f"cannot use columns parameter with orient='{orient}'") else: # pragma: no cover raise ValueError("only recognize index or columns for orient") @@ -1876,12 +1897,12 @@ def to_dict(self, orient: str = "dict", into=dict): orient = orient.lower() # GH32515 if orient.startswith(("d", "l", "s", "r", "i")) and orient not in { - "dict", - "list", - "series", - "split", - "records", - "index", + "dict", + "list", + "series", + "split", + "records", + "index", }: warnings.warn( "Using short name for 'orient' is deprecated. 
Only the " @@ -1912,57 +1933,51 @@ def to_dict(self, orient: str = "dict", into=dict): return into_c((k, v.tolist()) for k, v in self.items()) elif orient == "split": - return into_c( + return into_c(( + ("index", self.index.tolist()), + ("columns", self.columns.tolist()), ( - ("index", self.index.tolist()), - ("columns", self.columns.tolist()), - ( - "data", - [ - list(map(maybe_box_native, t)) - for t in self.itertuples(index=False, name=None) - ], - ), - ) - ) + "data", + [ + list(map(maybe_box_native, t)) + for t in self.itertuples(index=False, name=None) + ], + ), + )) elif orient == "tight": - return into_c( + return into_c(( + ("index", self.index.tolist()), + ("columns", self.columns.tolist()), ( - ("index", self.index.tolist()), - ("columns", self.columns.tolist()), - ( - "data", - [ - list(map(maybe_box_native, t)) - for t in self.itertuples(index=False, name=None) - ], - ), - ("index_names", list(self.index.names)), - ("column_names", list(self.columns.names)), - ) - ) + "data", + [ + list(map(maybe_box_native, t)) + for t in self.itertuples(index=False, name=None) + ], + ), + ("index_names", list(self.index.names)), + ("column_names", list(self.columns.names)), + )) elif orient == "series": return into_c((k, v) for k, v in self.items()) elif orient == "records": columns = self.columns.tolist() - rows = ( - dict(zip(columns, row)) - for row in self.itertuples(index=False, name=None) - ) + rows = (dict(zip(columns, row)) + for row in self.itertuples(index=False, name=None)) return [ - into_c((k, maybe_box_native(v)) for k, v in row.items()) for row in rows + into_c((k, maybe_box_native(v)) for k, v in row.items()) + for row in rows ] elif orient == "index": if not self.index.is_unique: - raise ValueError("DataFrame index must be unique for orient='index'.") - return into_c( - (t[0], dict(zip(self.columns, t[1:]))) - for t in self.itertuples(name=None) - ) + raise ValueError( + "DataFrame index must be unique for orient='index'.") + return into_c((t[0], dict(zip(self.columns, t[1:]))) + for t in self.itertuples(name=None)) else: raise ValueError(f"orient '{orient}' not understood") @@ -2162,8 +2177,8 @@ def from_records( columns = ensure_index(columns) def maybe_reorder( - arrays: list[ArrayLike], arr_columns: Index, columns: Index, index - ) -> tuple[list[ArrayLike], Index, Index | None]: + arrays: list[ArrayLike], arr_columns: Index, columns: Index, + index) -> tuple[list[ArrayLike], Index, Index | None]: """ If our desired 'columns' do not match the data's pre-existing 'arr_columns', we re-order our arrays. This is like a pre-emptive (cheap) reindex. 
@@ -2178,7 +2193,8 @@ def maybe_reorder( # for backward compat use an object Index instead of RangeIndex result_index = Index([]) - arrays, arr_columns = reorder_arrays(arrays, arr_columns, columns, length) + arrays, arr_columns = reorder_arrays(arrays, arr_columns, columns, + length) return arrays, arr_columns, result_index if is_iterator(data): @@ -2220,8 +2236,7 @@ def maybe_reorder( arr_columns = Index(arr_columns_list) arrays, arr_columns, result_index = maybe_reorder( - arrays, arr_columns, columns, index - ) + arrays, arr_columns, columns, index) elif isinstance(data, (np.ndarray, DataFrame)): arrays, columns = to_arrays(data, columns) @@ -2244,8 +2259,7 @@ def maybe_reorder( columns = arr_columns else: arrays, arr_columns, result_index = maybe_reorder( - arrays, arr_columns, columns, index - ) + arrays, arr_columns, columns, index) if exclude is None: exclude = set() @@ -2262,12 +2276,15 @@ def maybe_reorder( result_index = Index([], name=index) else: try: - index_data = [arrays[arr_columns.get_loc(field)] for field in index] + index_data = [ + arrays[arr_columns.get_loc(field)] for field in index + ] except (KeyError, TypeError): # raised by get_loc, see GH#29258 result_index = index else: - result_index = ensure_index_from_sequences(index_data, names=index) + result_index = ensure_index_from_sequences(index_data, + names=index) exclude.update(index) if any(exclude): @@ -2282,9 +2299,10 @@ def maybe_reorder( return cls(mgr) - def to_records( - self, index=True, column_dtypes=None, index_dtypes=None - ) -> np.recarray: + def to_records(self, + index=True, + column_dtypes=None, + index_dtypes=None) -> np.recarray: """ Convert DataFrame to a NumPy record array. @@ -2385,9 +2403,14 @@ def to_records( elif index_names[0] is None: index_names = ["index"] - names = [str(name) for name in itertools.chain(index_names, self.columns)] + names = [ + str(name) + for name in itertools.chain(index_names, self.columns) + ] else: - arrays = [np.asarray(self.iloc[:, i]) for i in range(len(self.columns))] + arrays = [ + np.asarray(self.iloc[:, i]) for i in range(len(self.columns)) + ] names = [str(c) for c in self.columns] index_names = [] @@ -2443,7 +2466,11 @@ def to_records( msg = f"Invalid dtype {dtype_mapping} specified for {element} {name}" raise ValueError(msg) - return np.rec.fromarrays(arrays, dtype={"names": names, "formats": formats}) + return np.rec.fromarrays(arrays, + dtype={ + "names": names, + "formats": formats + }) @classmethod def _from_arrays( @@ -2619,7 +2646,8 @@ def to_stata( >>> df.to_stata('animals.dta') # doctest: +SKIP """ if version not in (114, 117, 118, 119, None): - raise ValueError("Only formats 114, 117, 118 and 119 are supported.") + raise ValueError( + "Only formats 114, 117, 118 and 119 are supported.") if version == 114: if convert_strl is not None: raise ValueError("strl is not supported in format 114") @@ -2660,7 +2688,8 @@ def to_stata( writer.write_file() @deprecate_kwarg(old_arg_name="fname", new_arg_name="path") - def to_feather(self, path: FilePath | WriteBuffer[bytes], **kwargs) -> None: + def to_feather(self, path: FilePath | WriteBuffer[bytes], + **kwargs) -> None: """ Write a DataFrame to the binary Feather format. @@ -2860,7 +2889,8 @@ def to_parquet( " .. 
versionadded:: 0.25.0\n" " Ability to use str", ) - @Substitution(shared_params=fmt.common_docstring, returns=fmt.return_docstring) + @Substitution(shared_params=fmt.common_docstring, + returns=fmt.return_docstring) def to_html( self, buf: FilePath | WriteBuffer[str] | None = None, @@ -2949,11 +2979,13 @@ def to_html( @doc( storage_options=_shared_docs["storage_options"], - compression_options=_shared_docs["compression_options"] % "path_or_buffer", + compression_options=_shared_docs["compression_options"] % + "path_or_buffer", ) def to_xml( self, - path_or_buffer: FilePath | WriteBuffer[bytes] | WriteBuffer[str] | None = None, + path_or_buffer: FilePath | WriteBuffer[bytes] | WriteBuffer[str] + | None = None, index: bool = True, root_name: str | None = "data", row_name: str | None = "row", @@ -2966,7 +2998,8 @@ def to_xml( xml_declaration: bool | None = True, pretty_print: bool | None = True, parser: str | None = "lxml", - stylesheet: FilePath | ReadBuffer[str] | ReadBuffer[bytes] | None = None, + stylesheet: FilePath | ReadBuffer[str] | ReadBuffer[bytes] + | None = None, compression: CompressionOptions = "infer", storage_options: StorageOptions = None, ) -> str | None: @@ -3122,8 +3155,7 @@ def to_xml( TreeBuilder = LxmlXMLFormatter else: raise ImportError( - "lxml not found, please install or use the etree parser." - ) + "lxml not found, please install or use the etree parser.") elif parser == "etree": TreeBuilder = EtreeXMLFormatter @@ -3165,7 +3197,8 @@ def info( ) -> None: if null_counts is not None: if show_counts is not None: - raise ValueError("null_counts used with show_counts. Use show_counts.") + raise ValueError( + "null_counts used with show_counts. Use show_counts.") warnings.warn( "null_counts is deprecated. Use show_counts instead", FutureWarning, @@ -3273,13 +3306,15 @@ def memory_usage(self, index: bool = True, deep: bool = False) -> Series: 5244 """ result = self._constructor_sliced( - [c.memory_usage(index=False, deep=deep) for col, c in self.items()], + [ + c.memory_usage(index=False, deep=deep) + for col, c in self.items() + ], index=self.columns, ) if index: index_memory_usage = self._constructor_sliced( - self.index.memory_usage(deep=deep), index=["Index"] - ) + self.index.memory_usage(deep=deep), index=["Index"]) result = index_memory_usage._append(result) return result @@ -3391,26 +3426,31 @@ def transpose(self, *args, copy: bool = False) -> DataFrame: if copy: new_vals = new_vals.copy() - result = self._constructor(new_vals, index=self.columns, columns=self.index) + result = self._constructor(new_vals, + index=self.columns, + columns=self.index) - elif ( - self._is_homogeneous_type and dtypes and is_extension_array_dtype(dtypes[0]) - ): + elif (self._is_homogeneous_type and dtypes + and is_extension_array_dtype(dtypes[0])): # We have EAs with the same dtype. We can preserve that dtype in transpose. 
dtype = dtypes[0] arr_type = dtype.construct_array_type() values = self.values - new_values = [arr_type._from_sequence(row, dtype=dtype) for row in values] - result = type(self)._from_arrays( - new_values, index=self.columns, columns=self.index - ) + new_values = [ + arr_type._from_sequence(row, dtype=dtype) for row in values + ] + result = type(self)._from_arrays(new_values, + index=self.columns, + columns=self.index) else: new_arr = self.values.T if copy: new_arr = new_arr.copy() - result = self._constructor(new_arr, index=self.columns, columns=self.index) + result = self._constructor(new_arr, + index=self.columns, + columns=self.index) return result.__finalize__(self, method="transpose") @@ -3437,7 +3477,8 @@ def _ixs(self, i: int, axis: int = 0): new_values = self._mgr.fast_xs(i) # if we are a copy, mark as such - copy = isinstance(new_values, np.ndarray) and new_values.base is None + copy = isinstance(new_values, + np.ndarray) and new_values.base is None result = self._constructor_sliced( new_values, index=self.columns, @@ -3491,8 +3532,7 @@ def __getitem__(self, key): if indexer is not None: if isinstance(indexer, np.ndarray): indexer = lib.maybe_indices_to_slice( - indexer.astype(np.intp, copy=False), len(self) - ) + indexer.astype(np.intp, copy=False), len(self)) if isinstance(indexer, np.ndarray): # GH#43223 If we can not convert, use take return self.take(indexer, axis=0) @@ -3554,8 +3594,7 @@ def _getitem_bool_array(self, key): ) elif len(key) != len(self.index): raise ValueError( - f"Item wrong length {len(key)} instead of {len(self.index)}." - ) + f"Item wrong length {len(key)} instead of {len(self.index)}.") # check_bool_indexer will throw exception if Series key cannot # be reindexed to match DataFrame rows @@ -3574,9 +3613,9 @@ def _getitem_multilevel(self, key): result.columns = result_columns else: new_values = self.values[:, loc] - result = self._constructor( - new_values, index=self.index, columns=result_columns - ) + result = self._constructor(new_values, + index=self.index, + columns=result_columns) result = result.__finalize__(self) # If there is only one column being returned, and its name is @@ -3592,9 +3631,9 @@ def _getitem_multilevel(self, key): if top == "": result = result[""] if isinstance(result, Series): - result = self._constructor_sliced( - result, index=self.index, name=key - ) + result = self._constructor_sliced(result, + index=self.index, + name=key) result._set_is_copy(self) return result @@ -3656,11 +3695,8 @@ def __setitem__(self, key, value): self._setitem_array(key, value) elif isinstance(value, DataFrame): self._set_item_frame_value(key, value) - elif ( - is_list_like(value) - and not self.columns.is_unique - and 1 < len(self.columns.get_indexer_for([key])) == len(value) - ): + elif (is_list_like(value) and not self.columns.is_unique + and 1 < len(self.columns.get_indexer_for([key])) == len(value)): # Column to set is duplicated self._setitem_array([key], value) else: @@ -3764,13 +3800,13 @@ def _setitem_frame(self, key, value): # df[df > df2] = 0 if isinstance(key, np.ndarray): if key.shape != self.shape: - raise ValueError("Array conditional must be same shape as self") + raise ValueError( + "Array conditional must be same shape as self") key = self._constructor(key, **self._construct_axes_dict()) if key.size and not is_bool_dtype(key.values): raise TypeError( - "Must pass DataFrame or 2-d ndarray with boolean values only" - ) + "Must pass DataFrame or 2-d ndarray with boolean values only") self._check_inplace_setting(value) 
self._check_setitem_copy() @@ -3790,8 +3826,7 @@ def _set_item_frame_value(self, key, value: DataFrame) -> None: # align right-hand-side columns if self.columns # is multi-index and self[key] is a sub-frame if isinstance(self.columns, MultiIndex) and isinstance( - loc, (slice, Series, np.ndarray, Index) - ): + loc, (slice, Series, np.ndarray, Index)): cols = maybe_droplevels(cols, key) if len(cols) and not cols.equals(value.columns): value = value.reindex(cols, axis=1) @@ -3800,9 +3835,10 @@ def _set_item_frame_value(self, key, value: DataFrame) -> None: arraylike = _reindex_for_setitem(value, self.index) self._set_item_mgr(key, arraylike) - def _iset_item_mgr( - self, loc: int | slice | np.ndarray, value, inplace: bool = False - ) -> None: + def _iset_item_mgr(self, + loc: int | slice | np.ndarray, + value, + inplace: bool = False) -> None: # when called from _set_item_mgr loc can be anything returned from get_loc self._mgr.iset(loc, value, inplace=inplace) self._clear_item_cache() @@ -3844,22 +3880,22 @@ def _set_item(self, key, value) -> None: """ value = self._sanitize_column(value) - if ( - key in self.columns - and value.ndim == 1 - and not is_extension_array_dtype(value) - ): + if (key in self.columns and value.ndim == 1 + and not is_extension_array_dtype(value)): # broadcast across multiple columns if necessary - if not self.columns.is_unique or isinstance(self.columns, MultiIndex): + if not self.columns.is_unique or isinstance( + self.columns, MultiIndex): existing_piece = self[key] if isinstance(existing_piece, DataFrame): value = np.tile(value, (len(existing_piece.columns), 1)).T self._set_item_mgr(key, value) - def _set_value( - self, index: IndexLabel, col, value: Scalar, takeable: bool = False - ) -> None: + def _set_value(self, + index: IndexLabel, + col, + value: Scalar, + takeable: bool = False) -> None: """ Put single value at passed column and index. @@ -3914,7 +3950,9 @@ def _ensure_valid_index(self, value) -> None: if self.index.name is not None: index_copy.name = self.index.name - self._mgr = self._mgr.reindex_axis(index_copy, axis=1, fill_value=np.nan) + self._mgr = self._mgr.reindex_axis(index_copy, + axis=1, + fill_value=np.nan) def _box_col_values(self, values: SingleDataManager, loc: int) -> Series: """ @@ -4336,14 +4374,15 @@ def select_dtypes(self, include=None, exclude=None) -> DataFrame: 5 False 2.0 """ if not is_list_like(include): - include = (include,) if include is not None else () + include = (include, ) if include is not None else () if not is_list_like(exclude): - exclude = (exclude,) if exclude is not None else () + exclude = (exclude, ) if exclude is not None else () selection = (frozenset(include), frozenset(exclude)) if not any(selection): - raise ValueError("at least one of include or exclude must be nonempty") + raise ValueError( + "at least one of include or exclude must be nonempty") # convert the myriad valid dtypes object to a single representation def check_int_infer_dtype(dtypes): @@ -4351,7 +4390,8 @@ def check_int_infer_dtype(dtypes): for dtype in dtypes: # Numpy maps int to different types (int32, in64) on Windows and Linux # see https://github.com/numpy/numpy/issues/9464 - if (isinstance(dtype, str) and dtype == "int") or (dtype is int): + if (isinstance(dtype, str) + and dtype == "int") or (dtype is int): converted_dtypes.append(np.int32) converted_dtypes.append(np.int64) elif dtype == "float" or dtype is float: @@ -4369,12 +4409,13 @@ def check_int_infer_dtype(dtypes): # can't both include AND exclude! 
if not include.isdisjoint(exclude): - raise ValueError(f"include and exclude overlap on {(include & exclude)}") + raise ValueError( + f"include and exclude overlap on {(include & exclude)}") def dtype_predicate(dtype: DtypeObj, dtypes_set) -> bool: return issubclass(dtype.type, tuple(dtypes_set)) or ( - np.number in dtypes_set and getattr(dtype, "_is_numeric", False) - ) + np.number in dtypes_set + and getattr(dtype, "_is_numeric", False)) def predicate(arr: ArrayLike) -> bool: dtype = arr.dtype @@ -4446,10 +4487,8 @@ def insert( if allow_duplicates is lib.no_default: allow_duplicates = False if allow_duplicates and not self.flags.allows_duplicate_labels: - raise ValueError( - "Cannot specify 'allow_duplicates=True' when " - "'self.flags.allows_duplicate_labels' is False." - ) + raise ValueError("Cannot specify 'allow_duplicates=True' when " + "'self.flags.allows_duplicate_labels' is False.") if not allow_duplicates and column in self.columns: # Should this be a different kind of error?? raise ValueError(f"cannot insert {column}, already exists") @@ -4553,15 +4592,15 @@ def _sanitize_column(self, value) -> ArrayLike: @property def _series(self): return { - item: Series( - self._mgr.iget(idx), index=self.index, name=item, fastpath=True - ) + item: Series(self._mgr.iget(idx), + index=self.index, + name=item, + fastpath=True) for idx, item in enumerate(self.columns) } - def lookup( - self, row_labels: Sequence[IndexLabel], col_labels: Sequence[IndexLabel] - ) -> np.ndarray: + def lookup(self, row_labels: Sequence[IndexLabel], + col_labels: Sequence[IndexLabel]) -> np.ndarray: """ Label-based "fancy indexing" function for DataFrame. Given equal-length arrays of row and column labels, return an @@ -4585,12 +4624,10 @@ def lookup( numpy.ndarray The found values. """ - msg = ( - "The 'lookup' method is deprecated and will be " - "removed in a future version. " - "You can use DataFrame.melt and DataFrame.loc " - "as a substitute." - ) + msg = ("The 'lookup' method is deprecated and will be " + "removed in a future version. 
" + "You can use DataFrame.melt and DataFrame.loc " + "as a substitute.") warnings.warn(msg, FutureWarning, stacklevel=find_stack_level()) n = len(row_labels) @@ -4598,7 +4635,8 @@ def lookup( raise ValueError("Row labels must have same size as column labels") if not (self.index.is_unique and self.columns.is_unique): # GH#33041 - raise ValueError("DataFrame.lookup requires unique index and columns") + raise ValueError( + "DataFrame.lookup requires unique index and columns") thresh = 1000 if not self._is_mixed_type or n > thresh: @@ -4624,20 +4662,19 @@ def lookup( # ---------------------------------------------------------------------- # Reindexing and alignment - def _reindex_axes(self, axes, level, limit, tolerance, method, fill_value, copy): + def _reindex_axes(self, axes, level, limit, tolerance, method, fill_value, + copy): frame = self columns = axes["columns"] if columns is not None: - frame = frame._reindex_columns( - columns, method, copy, level, fill_value, limit, tolerance - ) + frame = frame._reindex_columns(columns, method, copy, level, + fill_value, limit, tolerance) index = axes["index"] if index is not None: - frame = frame._reindex_index( - index, method, copy, level, fill_value, limit, tolerance - ) + frame = frame._reindex_index(index, method, copy, level, + fill_value, limit, tolerance) return frame @@ -4651,9 +4688,11 @@ def _reindex_index( limit=None, tolerance=None, ): - new_index, indexer = self.index.reindex( - new_index, method=method, level=level, limit=limit, tolerance=tolerance - ) + new_index, indexer = self.index.reindex(new_index, + method=method, + level=level, + limit=limit, + tolerance=tolerance) return self._reindex_with_indexers( {0: [new_index, indexer]}, copy=copy, @@ -4671,9 +4710,11 @@ def _reindex_columns( limit=None, tolerance=None, ): - new_columns, indexer = self.columns.reindex( - new_columns, method=method, level=level, limit=limit, tolerance=tolerance - ) + new_columns, indexer = self.columns.reindex(new_columns, + method=method, + level=level, + limit=limit, + tolerance=tolerance) return self._reindex_with_indexers( {1: [new_columns, indexer]}, copy=copy, @@ -4681,9 +4722,8 @@ def _reindex_columns( allow_dups=False, ) - def _reindex_multi( - self, axes: dict[str, Index], copy: bool, fill_value - ) -> DataFrame: + def _reindex_multi(self, axes: dict[str, Index], copy: bool, + fill_value) -> DataFrame: """ We are guaranteed non-Nones in the axes. """ @@ -4698,11 +4738,18 @@ def _reindex_multi( # ensures that self.values is cheap. It may be worth making this # condition more specific. indexer = row_indexer, col_indexer - new_values = take_2d_multi(self.values, indexer, fill_value=fill_value) - return self._constructor(new_values, index=new_index, columns=new_columns) + new_values = take_2d_multi(self.values, + indexer, + fill_value=fill_value) + return self._constructor(new_values, + index=new_index, + columns=new_columns) else: return self._reindex_with_indexers( - {0: [new_index, row_indexer], 1: [new_columns, col_indexer]}, + { + 0: [new_index, row_indexer], + 1: [new_columns, col_indexer] + }, copy=copy, fill_value=fill_value, ) @@ -4735,9 +4782,10 @@ def align( ) @overload - def set_axis( - self, labels, axis: Axis = ..., inplace: Literal[False] = ... - ) -> DataFrame: + def set_axis(self, + labels, + axis: Axis = ..., + inplace: Literal[False] = ...) -> DataFrame: ... @overload @@ -4749,14 +4797,15 @@ def set_axis(self, labels, *, inplace: Literal[True]) -> None: ... 
@overload - def set_axis( - self, labels, axis: Axis = ..., inplace: bool = ... - ) -> DataFrame | None: + def set_axis(self, + labels, + axis: Axis = ..., + inplace: bool = ...) -> DataFrame | None: ... - @deprecate_nonkeyword_arguments(version=None, allowed_args=["self", "labels"]) - @Appender( - """ + @deprecate_nonkeyword_arguments(version=None, + allowed_args=["self", "labels"]) + @Appender(""" Examples -------- >>> df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) @@ -4785,8 +4834,7 @@ def set_axis( 0 1 4 1 2 5 2 3 6 - """ - ) + """) @Substitution( **_shared_doc_kwargs, extended_summary_sub=" column or", @@ -4811,14 +4859,16 @@ def set_axis(self, labels, axis: Axis = 0, inplace: bool = False): ], ) def reindex(self, *args, **kwargs) -> DataFrame: - axes = validate_axis_style_args(self, args, kwargs, "labels", "reindex") + axes = validate_axis_style_args(self, args, kwargs, "labels", + "reindex") kwargs.update(axes) # Pop these, since the values are in `kwargs` under different names kwargs.pop("axis", None) kwargs.pop("labels", None) return super().reindex(**kwargs) - @deprecate_nonkeyword_arguments(version=None, allowed_args=["self", "labels"]) + @deprecate_nonkeyword_arguments(version=None, + allowed_args=["self", "labels"]) def drop( self, labels=None, @@ -5221,7 +5271,8 @@ def fillna( ) -> DataFrame | None: ... - @deprecate_nonkeyword_arguments(version=None, allowed_args=["self", "value"]) + @deprecate_nonkeyword_arguments(version=None, + allowed_args=["self", "value"]) @doc(NDFrame.fillna, **_shared_doc_kwargs) def fillna( self, @@ -5303,9 +5354,8 @@ def replace( method=method, ) - def _replace_columnwise( - self, mapping: dict[Hashable, tuple[Any, Any]], inplace: bool, regex - ): + def _replace_columnwise(self, mapping: dict[Hashable, tuple[Any, Any]], + inplace: bool, regex): """ Dispatch to Series.replace column-wise. @@ -5366,45 +5416,39 @@ def shift( for col in range(min(ncols, abs(periods))): # Define filler inside loop so we get a copy filler = self.iloc[:, -1].shift(len(self)) - result.insert( - len(result.columns), label, filler, allow_duplicates=True - ) + result.insert(len(result.columns), + label, + filler, + allow_duplicates=True) result.columns = self.columns.copy() return result - elif ( - axis == 1 - and periods != 0 - and fill_value is not lib.no_default - and ncols > 0 - ): + elif (axis == 1 and periods != 0 and fill_value is not lib.no_default + and ncols > 0): arrays = self._mgr.arrays if len(arrays) > 1 or ( - # If we only have one block and we know that we can't - # keep the same dtype (i.e. the _can_hold_element check) - # then we can go through the reindex_indexer path - # (and avoid casting logic in the Block method). - # The exception to this (until 2.0) is datetimelike - # dtypes with integers, which cast. - not can_hold_element(arrays[0], fill_value) - # TODO(2.0): remove special case for integer-with-datetimelike - # once deprecation is enforced - and not ( - lib.is_integer(fill_value) and needs_i8_conversion(arrays[0].dtype) - ) - ): + # If we only have one block and we know that we can't + # keep the same dtype (i.e. the _can_hold_element check) + # then we can go through the reindex_indexer path + # (and avoid casting logic in the Block method). + # The exception to this (until 2.0) is datetimelike + # dtypes with integers, which cast. 
+ not can_hold_element(arrays[0], fill_value) + # TODO(2.0): remove special case for integer-with-datetimelike + # once deprecation is enforced + and not (lib.is_integer(fill_value) + and needs_i8_conversion(arrays[0].dtype))): # GH#35488 we need to watch out for multi-block cases # We only get here with fill_value not-lib.no_default nper = abs(periods) nper = min(nper, ncols) if periods > 0: - indexer = np.array( - [-1] * nper + list(range(ncols - periods)), dtype=np.intp - ) + indexer = np.array([-1] * nper + + list(range(ncols - periods)), + dtype=np.intp) else: - indexer = np.array( - list(range(nper, ncols)) + [-1] * nper, dtype=np.intp - ) + indexer = np.array(list(range(nper, ncols)) + [-1] * nper, + dtype=np.intp) mgr = self._mgr.reindex_indexer( self.columns, indexer, @@ -5415,11 +5459,13 @@ def shift( res_df = self._constructor(mgr) return res_df.__finalize__(self, method="shift") - return super().shift( - periods=periods, freq=freq, axis=axis, fill_value=fill_value - ) + return super().shift(periods=periods, + freq=freq, + axis=axis, + fill_value=fill_value) - @deprecate_nonkeyword_arguments(version=None, allowed_args=["self", "keys"]) + @deprecate_nonkeyword_arguments(version=None, + allowed_args=["self", "keys"]) def set_index( self, keys, @@ -5522,15 +5568,14 @@ def set_index( if not isinstance(keys, list): keys = [keys] - err_msg = ( - 'The parameter "keys" may be a column key, one-dimensional ' - "array, or a list containing only valid column keys and " - "one-dimensional arrays." - ) + err_msg = ('The parameter "keys" may be a column key, one-dimensional ' + "array, or a list containing only valid column keys and " + "one-dimensional arrays.") missing: list[Hashable] = [] for col in keys: - if isinstance(col, (Index, Series, np.ndarray, list, abc.Iterator)): + if isinstance(col, + (Index, Series, np.ndarray, list, abc.Iterator)): # arrays are fine as long as they are one-dimensional # iterators get converted to list below if getattr(col, "ndim", 1) != 1: @@ -5600,8 +5645,7 @@ def set_index( # ensure_index_from_sequences would not raise for append=False. raise ValueError( f"Length mismatch: Expected {len(self)} rows, " - f"received array of length {len(arrays[-1])}" - ) + f"received array of length {len(arrays[-1])}") index = ensure_index_from_sequences(arrays, names) @@ -5692,7 +5736,8 @@ def reset_index( ) -> DataFrame | None: ... 
- @deprecate_nonkeyword_arguments(version=None, allowed_args=["self", "level"]) + @deprecate_nonkeyword_arguments(version=None, + allowed_args=["self", "level"]) def reset_index( self, level: Hashable | Sequence[Hashable] | None = None, @@ -5854,7 +5899,8 @@ class max type else: new_obj = self.copy() if allow_duplicates is not lib.no_default: - allow_duplicates = validate_bool_kwarg(allow_duplicates, "allow_duplicates") + allow_duplicates = validate_bool_kwarg(allow_duplicates, + "allow_duplicates") new_index = default_index(len(new_obj)) if level is not None: @@ -5871,8 +5917,9 @@ class max type to_insert = zip(self.index.levels, self.index.codes) else: default = "index" if "index" not in self else "level_0" - names = [default] if self.index.name is None else [self.index.name] - to_insert = ((self.index, None),) + names = [default + ] if self.index.name is None else [self.index.name] + to_insert = ((self.index, None), ) multi_col = isinstance(self.columns, MultiIndex) for i, (lev, lab) in reversed(list(enumerate(to_insert))): @@ -5880,13 +5927,13 @@ class max type continue name = names[i] if multi_col: - col_name = list(name) if isinstance(name, tuple) else [name] + col_name = list(name) if isinstance(name, + tuple) else [name] if col_fill is None: if len(col_name) not in (1, self.columns.nlevels): raise ValueError( "col_fill=None is incompatible " - f"with incomplete column name {name}" - ) + f"with incomplete column name {name}") col_fill = col_name[0] lev_num = self.columns._get_level_number(col_level) @@ -5902,9 +5949,10 @@ class max type if lab is not None: # if we have the codes, extract the values with a mask - level_values = algorithms.take( - level_values, lab, allow_fill=True, fill_value=lev._na_value - ) + level_values = algorithms.take(level_values, + lab, + allow_fill=True, + fill_value=lev._na_value) new_obj.insert( 0, @@ -6060,7 +6108,8 @@ def dropna( inplace = validate_bool_kwarg(inplace, "inplace") if isinstance(axis, (tuple, list)): # GH20987 - raise TypeError("supplying multiple axes to axis is no longer supported.") + raise TypeError( + "supplying multiple axes to axis is no longer supported.") axis = self._get_axis_number(axis) agg_axis = 1 - axis @@ -6102,7 +6151,8 @@ def dropna( else: return result - @deprecate_nonkeyword_arguments(version=None, allowed_args=["self", "subset"]) + @deprecate_nonkeyword_arguments(version=None, + allowed_args=["self", "subset"]) def drop_duplicates( self, subset: Hashable | Sequence[Hashable] | None = None, @@ -6306,13 +6356,9 @@ def f(vals) -> tuple[np.ndarray, int]: # Incompatible types in assignment (expression has type "Index", variable # has type "Sequence[Any]") subset = self.columns # type: ignore[assignment] - elif ( - not np.iterable(subset) - or isinstance(subset, str) - or isinstance(subset, tuple) - and subset in self.columns - ): - subset = (subset,) + elif (not np.iterable(subset) or isinstance(subset, str) + or isinstance(subset, tuple) and subset in self.columns): + subset = (subset, ) # needed for mypy since can't narrow types using np.iterable subset = cast(Sequence, subset) @@ -6340,7 +6386,8 @@ def f(vals) -> tuple[np.ndarray, int]: sort=False, xnull=False, ) - result = self._constructor_sliced(duplicated(ids, keep), index=self.index) + result = self._constructor_sliced(duplicated(ids, keep), + index=self.index) return result.__finalize__(self, method="duplicated") # ---------------------------------------------------------------------- @@ -6383,9 +6430,10 @@ def sort_values( # type: ignore[override] for (k, name) 
in zip(keys, by) ] - indexer = lexsort_indexer( - keys, orders=ascending, na_position=na_position, key=key - ) + indexer = lexsort_indexer(keys, + orders=ascending, + na_position=na_position, + key=key) elif len(by): # len(by) == 1 @@ -6401,20 +6449,21 @@ def sort_values( # type: ignore[override] if isinstance(ascending, (tuple, list)): ascending = ascending[0] - indexer = nargsort( - k, kind=kind, ascending=ascending, na_position=na_position, key=key - ) + indexer = nargsort(k, + kind=kind, + ascending=ascending, + na_position=na_position, + key=key) else: return self.copy() - new_data = self._mgr.take( - indexer, axis=self._get_block_manager_axis(axis), verify=False - ) + new_data = self._mgr.take(indexer, + axis=self._get_block_manager_axis(axis), + verify=False) if ignore_index: - new_data.set_axis( - self._get_block_manager_axis(axis), default_index(len(indexer)) - ) + new_data.set_axis(self._get_block_manager_axis(axis), + default_index(len(indexer))) result = self._constructor(new_data) if inplace: @@ -6656,13 +6705,15 @@ def value_counts( # Force MultiIndex for single column if len(subset) == 1: - counts.index = MultiIndex.from_arrays( - [counts.index], names=[counts.index.name] - ) + counts.index = MultiIndex.from_arrays([counts.index], + names=[counts.index.name]) return counts - def nlargest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame: + def nlargest(self, + n: int, + columns: IndexLabel, + keep: str = "first") -> DataFrame: """ Return the first `n` rows ordered by `columns` in descending order. @@ -6767,9 +6818,13 @@ def nlargest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFram Italy 59000000 1937894 IT Brunei 434000 12128 BN """ - return algorithms.SelectNFrame(self, n=n, keep=keep, columns=columns).nlargest() + return algorithms.SelectNFrame(self, n=n, keep=keep, + columns=columns).nlargest() - def nsmallest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame: + def nsmallest(self, + n: int, + columns: IndexLabel, + keep: str = "first") -> DataFrame: """ Return the first `n` rows ordered by `columns` in ascending order. @@ -6865,20 +6920,16 @@ def nsmallest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFra Anguilla 11300 311 AI Nauru 337000 182 NR """ - return algorithms.SelectNFrame( - self, n=n, keep=keep, columns=columns - ).nsmallest() + return algorithms.SelectNFrame(self, n=n, keep=keep, + columns=columns).nsmallest() @doc( Series.swaplevel, klass=_shared_doc_kwargs["klass"], - extra_params=dedent( - """axis : {0 or 'index', 1 or 'columns'}, default 0 + extra_params=dedent("""axis : {0 or 'index', 1 or 'columns'}, default 0 The axis to swap levels on. 
0 or 'index' for row-wise, 1 or - 'columns' for column-wise.""" - ), - examples=dedent( - """\ + 'columns' for column-wise."""), + examples=dedent("""\ Examples -------- >>> df = pd.DataFrame( @@ -6928,15 +6979,18 @@ def nsmallest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFra History Final exam January A Geography Final exam February B History Coursework March A - Geography Coursework April C""" - ), + Geography Coursework April C"""), ) - def swaplevel(self, i: Axis = -2, j: Axis = -1, axis: Axis = 0) -> DataFrame: + def swaplevel(self, + i: Axis = -2, + j: Axis = -1, + axis: Axis = 0) -> DataFrame: result = self.copy() axis = self._get_axis_number(axis) - if not isinstance(result._get_axis(axis), MultiIndex): # pragma: no cover + if not isinstance(result._get_axis(axis), + MultiIndex): # pragma: no cover raise TypeError("Can only swap levels on a hierarchical axis.") if axis == 0: @@ -6947,7 +7001,9 @@ def swaplevel(self, i: Axis = -2, j: Axis = -1, axis: Axis = 0) -> DataFrame: result.columns = result.columns.swaplevel(i, j) return result - def reorder_levels(self, order: Sequence[Axis], axis: Axis = 0) -> DataFrame: + def reorder_levels(self, + order: Sequence[Axis], + axis: Axis = 0) -> DataFrame: """ Rearrange index levels using input order. May not drop or duplicate levels. @@ -6989,7 +7045,8 @@ class diet Reptiles Snakes """ axis = self._get_axis_number(axis) - if not isinstance(self._get_axis(axis), MultiIndex): # pragma: no cover + if not isinstance(self._get_axis(axis), + MultiIndex): # pragma: no cover raise TypeError("Can only reorder levels on a hierarchical axis.") result = self.copy() @@ -7008,7 +7065,11 @@ class diet def _cmp_method(self, other, op): axis = 1 # only relevant for Series other case - self, other = ops.align_method_FRAME(self, other, axis, flex=False, level=None) + self, other = ops.align_method_FRAME(self, + other, + axis, + flex=False, + level=None) # See GH#4537 for discussion of scalar op behavior new_data = self._dispatch_frame_op(other, op, axis=axis) @@ -7019,16 +7080,23 @@ def _arith_method(self, other, op): return ops.frame_arith_method_with_reindex(self, other, op) axis = 1 # only relevant for Series other case - other = ops.maybe_prepare_scalar_for_op(other, (self.shape[axis],)) + other = ops.maybe_prepare_scalar_for_op(other, (self.shape[axis], )) - self, other = ops.align_method_FRAME(self, other, axis, flex=True, level=None) + self, other = ops.align_method_FRAME(self, + other, + axis, + flex=True, + level=None) new_data = self._dispatch_frame_op(other, op, axis=axis) return self._construct_result(new_data) _logical_method = _arith_method - def _dispatch_frame_op(self, right, func: Callable, axis: int | None = None): + def _dispatch_frame_op(self, + right, + func: Callable, + axis: int | None = None): """ Evaluate the frame operation func(left, right) by evaluating column-by-column, dispatching to the Series implementation. 
@@ -7093,15 +7161,19 @@ def _dispatch_frame_op(self, right, func: Callable, axis: int | None = None): right = right._values with np.errstate(all="ignore"): - arrays = [array_op(left, right) for left in self._iter_column_arrays()] + arrays = [ + array_op(left, right) + for left in self._iter_column_arrays() + ] else: # Remaining cases have less-obvious dispatch rules raise NotImplementedError(right) - return type(self)._from_arrays( - arrays, self.columns, self.index, verify_integrity=False - ) + return type(self)._from_arrays(arrays, + self.columns, + self.index, + verify_integrity=False) def _combine_frame(self, other: DataFrame, func, fill_value=None): # at this point we have `self._indexed_same(other)` @@ -7277,9 +7349,11 @@ def compare( keep_equal=keep_equal, ) - def combine( - self, other: DataFrame, func, fill_value=None, overwrite: bool = True - ) -> DataFrame: + def combine(self, + other: DataFrame, + func, + fill_value=None, + overwrite: bool = True) -> DataFrame: """ Perform column-wise combine with another DataFrame. @@ -7643,7 +7717,8 @@ def update( if join != "left": # pragma: no cover raise NotImplementedError("Only left join is supported") if errors not in ["ignore", "raise"]: - raise ValueError("The parameter errors must be either 'ignore' or 'raise'") + raise ValueError( + "The parameter errors must be either 'ignore' or 'raise'") if not isinstance(other, DataFrame): other = DataFrame(other) @@ -7676,8 +7751,7 @@ def update( # ---------------------------------------------------------------------- # Data reshaping - @Appender( - """ + @Appender(""" Examples -------- >>> df = pd.DataFrame({'Animal': ['Falcon', 'Falcon', @@ -7757,8 +7831,7 @@ def update( a 13.0 13.0 b 12.3 123.0 NaN 12.3 33.0 -""" - ) +""") @Appender(_shared_docs["groupby"] % _shared_doc_kwargs) def groupby( self, @@ -7776,10 +7849,8 @@ def groupby( if squeeze is not no_default: warnings.warn( - ( - "The `squeeze` parameter is deprecated and " - "will be removed in a future version." - ), + ("The `squeeze` parameter is deprecated and " + "will be removed in a future version."), FutureWarning, stacklevel=find_stack_level(), ) @@ -7806,9 +7877,7 @@ def groupby( dropna=dropna, ) - _shared_docs[ - "pivot" - ] = """ + _shared_docs["pivot"] = """ Return reshaped DataFrame organized by given index / column values. Reshape data (produce a "pivot" table) based on column values. Uses @@ -7959,9 +8028,7 @@ def pivot(self, index=None, columns=None, values=None) -> DataFrame: return pivot(self, index=index, columns=columns, values=values) - _shared_docs[ - "pivot_table" - ] = """ + _shared_docs["pivot_table"] = """ Create a spreadsheet-style pivot table as a DataFrame. 
The levels in the pivot table will be stored in MultiIndex objects @@ -8408,8 +8475,7 @@ def explode( if is_scalar(column) or isinstance(column, tuple): columns = [column] elif isinstance(column, list) and all( - map(lambda c: is_scalar(c) or isinstance(c, tuple), column) - ): + map(lambda c: is_scalar(c) or isinstance(c, tuple), column)): if not column: raise ValueError("column must be nonempty") if len(column) > len(set(column)): @@ -8426,7 +8492,8 @@ def explode( counts0 = self[columns[0]].apply(mylen) for c in columns[1:]: if not all(counts0 == self[c].apply(mylen)): - raise ValueError("columns must have matching element counts") + raise ValueError( + "columns must have matching element counts") result = DataFrame({c: df[c].explode() for c in columns}) result = df.drop(columns, axis=1).join(result) if ignore_index: @@ -8534,8 +8601,7 @@ def melt( extra_params="axis : {0 or 'index', 1 or 'columns'}, default 0\n " "Take difference over rows (0) or columns (1).\n", other_klass="Series", - examples=dedent( - """ + examples=dedent(""" Difference with previous row >>> df = pd.DataFrame({'a': [1, 2, 3, 4, 5, 6], @@ -8598,16 +8664,14 @@ def melt( >>> df.diff() a 0 NaN - 1 255.0""" - ), + 1 255.0"""), ) def diff(self, periods: int = 1, axis: Axis = 0) -> DataFrame: if not lib.is_integer(periods): - if not ( - is_float(periods) - # error: "int" has no attribute "is_integer" - and periods.is_integer() # type: ignore[attr-defined] - ): + if not (is_float(periods) + # error: "int" has no attribute "is_integer" + and periods.is_integer() # type: ignore[attr-defined] + ): raise ValueError("periods must be an integer") periods = int(periods) @@ -8646,8 +8710,7 @@ def _gotitem( # TODO: _shallow_copy(subset)? return subset[key] - _agg_summary_and_see_also_doc = dedent( - """ + _agg_summary_and_see_also_doc = dedent(""" The aggregation operations are always performed over an axis, either the index (default) or the column axis. This behavior is different from `numpy` aggregation functions (`mean`, `median`, `prod`, `sum`, `std`, @@ -8667,11 +8730,9 @@ def _gotitem( core.window.Expanding : Perform operations over expanding window. core.window.ExponentialMovingWindow : Perform operation over exponential weighted window. - """ - ) + """) - _agg_examples_doc = dedent( - """ + _agg_examples_doc = dedent(""" Examples -------- >>> df = pd.DataFrame([[1, 2, 3], @@ -8712,8 +8773,7 @@ def _gotitem( 2 8.0 3 NaN dtype: float64 - """ - ) + """) @doc( _shared_docs["aggregate"], @@ -8753,9 +8813,11 @@ def aggregate(self, func=None, axis: Axis = 0, *args, **kwargs): klass=_shared_doc_kwargs["klass"], axis=_shared_doc_kwargs["axis"], ) - def transform( - self, func: AggFuncType, axis: Axis = 0, *args, **kwargs - ) -> DataFrame: + def transform(self, + func: AggFuncType, + axis: Axis = 0, + *args, + **kwargs) -> DataFrame: from pandas.core.apply import frame_apply op = frame_apply(self, func=func, axis=axis, args=args, kwargs=kwargs) @@ -8764,13 +8826,13 @@ def transform( return result def apply( - self, - func: AggFuncType, - axis: Axis = 0, - raw: bool = False, - result_type=None, - args=(), - **kwargs, + self, + func: AggFuncType, + axis: Axis = 0, + raw: bool = False, + result_type=None, + args=(), + **kwargs, ): """ Apply a function along an axis of the DataFrame. 
@@ -8922,9 +8984,10 @@ def apply( ) return op.apply().__finalize__(self, method="apply") - def applymap( - self, func: PythonFuncType, na_action: str | None = None, **kwargs - ) -> DataFrame: + def applymap(self, + func: PythonFuncType, + na_action: str | None = None, + **kwargs) -> DataFrame: """ Apply a function to a Dataframe elementwise. @@ -8994,8 +9057,7 @@ def applymap( """ if na_action not in {"ignore", None}: raise ValueError( - f"na_action must be 'ignore' or None. Got {repr(na_action)}" - ) + f"na_action must be 'ignore' or None. Got {repr(na_action)}") ignore_na = na_action == "ignore" func = functools.partial(func, **kwargs) @@ -9003,7 +9065,9 @@ def applymap( def infer(x): if x.empty: return lib.map_infer(x, func, ignore_na=ignore_na) - return lib.map_infer(x.astype(object)._values, func, ignore_na=ignore_na) + return lib.map_infer(x.astype(object)._values, + func, + ignore_na=ignore_na) return self.apply(infer).__finalize__(self, "applymap") @@ -9133,13 +9197,13 @@ def _append( if isinstance(other, (Series, dict)): if isinstance(other, dict): if not ignore_index: - raise TypeError("Can only append a dict if ignore_index=True") + raise TypeError( + "Can only append a dict if ignore_index=True") other = Series(other) if other.name is None and not ignore_index: raise TypeError( "Can only append a Series if ignore_index=True " - "or if the Series has a name" - ) + "or if the Series has a name") index = Index([other.name], name=self.index.name) idx_diff = other.index.difference(self.columns) @@ -9169,11 +9233,8 @@ def _append( verify_integrity=verify_integrity, sort=sort, ) - if ( - combined_columns is not None - and not sort - and not combined_columns.equals(result.columns) - ): + if (combined_columns is not None and not sort + and not combined_columns.equals(result.columns)): # TODO: reindexing here is a kludge bc union_indexes does not # pass sort to index.union, xref #43375 # combined_columns.equals check is necessary for preserving dtype @@ -9335,9 +9396,12 @@ def join( 4 K0 A4 B0 5 K1 A5 B1 """ - return self._join_compat( - other, on=on, how=how, lsuffix=lsuffix, rsuffix=rsuffix, sort=sort - ) + return self._join_compat(other, + on=on, + how=how, + lsuffix=lsuffix, + rsuffix=rsuffix, + sort=sort) def _join_compat( self, @@ -9389,21 +9453,27 @@ def _join_compat( # join indexes only using concat if can_concat: if how == "left": - res = concat( - frames, axis=1, join="outer", verify_integrity=True, sort=sort - ) + res = concat(frames, + axis=1, + join="outer", + verify_integrity=True, + sort=sort) return res.reindex(self.index, copy=False) else: - return concat( - frames, axis=1, join=how, verify_integrity=True, sort=sort - ) + return concat(frames, + axis=1, + join=how, + verify_integrity=True, + sort=sort) joined = frames[0] for frame in frames[1:]: - joined = merge( - joined, frame, how=how, left_index=True, right_index=True - ) + joined = merge(joined, + frame, + how=how, + left_index=True, + right_index=True) return joined @@ -9442,9 +9512,10 @@ def merge( validate=validate, ) - def round( - self, decimals: int | dict[IndexLabel, int] | Series = 0, *args, **kwargs - ) -> DataFrame: + def round(self, + decimals: int | dict[IndexLabel, int] | Series = 0, + *args, + **kwargs) -> DataFrame: """ Round a DataFrame to a variable number of decimal places. 
@@ -9541,20 +9612,21 @@ def _series_round(ser: Series, decimals: int): if isinstance(decimals, Series) and not decimals.index.is_unique: raise ValueError("Index of decimals must be unique") if is_dict_like(decimals) and not all( - is_integer(value) for _, value in decimals.items() - ): + is_integer(value) for _, value in decimals.items()): raise TypeError("Values in decimals must be integers") new_cols = list(_dict_round(self, decimals)) elif is_integer(decimals): # Dispatch to Series.round new_cols = [_series_round(v, decimals) for _, v in self.items()] else: - raise TypeError("decimals must be an integer, a dict-like or a Series") + raise TypeError( + "decimals must be an integer, a dict-like or a Series") if len(new_cols) > 0: - return self._constructor( - concat(new_cols, axis=1), index=self.index, columns=self.columns - ).__finalize__(self, method="round") + return self._constructor(concat(new_cols, axis=1), + index=self.index, + columns=self.columns).__finalize__( + self, method="round") else: return self @@ -9658,15 +9730,15 @@ def corr( correl[i, j] = c correl[j, i] = c else: - raise ValueError( - "method must be either 'pearson', " - "'spearman', 'kendall', or a callable, " - f"'{method}' was supplied" - ) + raise ValueError("method must be either 'pearson', " + "'spearman', 'kendall', or a callable, " + f"'{method}' was supplied") return self._constructor(correl, index=idx, columns=cols) - def cov(self, min_periods: int | None = None, ddof: int | None = 1) -> DataFrame: + def cov(self, + min_periods: int | None = None, + ddof: int | None = 1) -> DataFrame: """ Compute pairwise covariance of columns, excluding NA/null values. @@ -9782,7 +9854,11 @@ def cov(self, min_periods: int | None = None, ddof: int | None = 1) -> DataFrame return self._constructor(base_cov, index=idx, columns=cols) - def corrwith(self, other, axis: Axis = 0, drop=False, method="pearson") -> Series: + def corrwith(self, + other, + axis: Axis = 0, + drop=False, + method="pearson") -> Series: """ Compute pairwise correlation. 
@@ -9860,8 +9936,7 @@ def corrwith(self, other, axis: Axis = 0, drop=False, method="pearson") -> Serie for i, r in enumerate(ndf): nonnull_mask = ~np.isnan(r) & ~np.isnan(k) corrs[numeric_cols[i]] = np.corrcoef( - r[nonnull_mask], k[nonnull_mask] - )[0, 1] + r[nonnull_mask], k[nonnull_mask])[0, 1] else: for i, r in enumerate(ndf): nonnull_mask = ~np.isnan(r) & ~np.isnan(k) @@ -9871,7 +9946,8 @@ def corrwith(self, other, axis: Axis = 0, drop=False, method="pearson") -> Serie )[0, 1] return Series(corrs) else: - return this.apply(lambda x: other.corr(x, method=method), axis=axis) + return this.apply(lambda x: other.corr(x, method=method), + axis=axis) other = other._get_numeric_data() left, right = this.align(other, join="inner", copy=False) @@ -9899,16 +9975,14 @@ def corrwith(self, other, axis: Axis = 0, drop=False, method="pearson") -> Serie def c(x): return nanops.nancorr(x[0], x[1], method=method) - correl = self._constructor_sliced( - map(c, zip(left.values.T, right.values.T)), index=left.columns - ) + correl = self._constructor_sliced(map( + c, zip(left.values.T, right.values.T)), + index=left.columns) else: - raise ValueError( - f"Invalid method {method} was passed, " - "valid methods are: 'pearson', 'kendall', " - "'spearman', or callable" - ) + raise ValueError(f"Invalid method {method} was passed, " + "valid methods are: 'pearson', 'kendall', " + "'spearman', or callable") if not drop: # Find non-matching labels along the given axis @@ -9919,17 +9993,17 @@ def c(x): if len(idx_diff) > 0: correl = correl._append( - Series([np.nan] * len(idx_diff), index=idx_diff) - ) + Series([np.nan] * len(idx_diff), index=idx_diff)) return correl # ---------------------------------------------------------------------- # ndarray-like stats methods - def count( - self, axis: Axis = 0, level: Level | None = None, numeric_only: bool = False - ): + def count(self, + axis: Axis = 0, + level: Level | None = None, + numeric_only: bool = False): """ Count non-NA cells for each column or row. 
@@ -10006,7 +10080,9 @@ def count( FutureWarning, stacklevel=find_stack_level(), ) - res = self._count_level(level, axis=axis, numeric_only=numeric_only) + res = self._count_level(level, + axis=axis, + numeric_only=numeric_only) return res.__finalize__(self, method="count") if numeric_only: @@ -10016,7 +10092,8 @@ def count( # GH #423 if len(frame._get_axis(axis)) == 0: - result = self._constructor_sliced(0, index=frame._get_agg_axis(axis)) + result = self._constructor_sliced(0, + index=frame._get_agg_axis(axis)) else: if frame._is_mixed_type or frame._mgr.any_extension_types: # the or any_extension_types is really only hit for single- @@ -10027,12 +10104,14 @@ def count( series_counts = notna(frame).sum(axis=axis) counts = series_counts.values result = self._constructor_sliced( - counts, index=frame._get_agg_axis(axis) - ) + counts, index=frame._get_agg_axis(axis)) return result.astype("int64").__finalize__(self, method="count") - def _count_level(self, level: Level, axis: int = 0, numeric_only: bool = False): + def _count_level(self, + level: Level, + axis: int = 0, + numeric_only: bool = False): if numeric_only: frame = self._get_numeric_data() else: @@ -10068,12 +10147,19 @@ def _count_level(self, level: Level, axis: int = 0, numeric_only: bool = False): level_name = count_axis._names[level] level_index = count_axis.levels[level]._rename(name=level_name) level_codes = ensure_platform_int(count_axis.codes[level]) - counts = lib.count_level_2d(mask, level_codes, len(level_index), axis=axis) + counts = lib.count_level_2d(mask, + level_codes, + len(level_index), + axis=axis) if axis == 1: - result = self._constructor(counts, index=agg_axis, columns=level_index) + result = self._constructor(counts, + index=agg_axis, + columns=level_index) else: - result = self._constructor(counts, index=level_index, columns=agg_axis) + result = self._constructor(counts, + index=level_index, + columns=agg_axis) return result @@ -10127,8 +10213,7 @@ def func(values: np.ndarray): def blk_func(values, axis=1): if isinstance(values, ExtensionArray): if not is_1d_only_ea_dtype(values.dtype) and not isinstance( - self._mgr, ArrayManager - ): + self._mgr, ArrayManager): return values._reduce(name, axis=1, skipna=skipna, **kwds) return values._reduce(name, skipna=skipna, **kwds) else: @@ -10360,9 +10445,11 @@ def idxmin(self, axis: Axis = 0, skipna: bool = True) -> Series: """ axis = self._get_axis_number(axis) - res = self._reduce( - nanops.nanargmin, "argmin", axis=axis, skipna=skipna, numeric_only=False - ) + res = self._reduce(nanops.nanargmin, + "argmin", + axis=axis, + skipna=skipna, + numeric_only=False) indices = res._values # indices will always be np.ndarray since axis is not None and @@ -10437,9 +10524,11 @@ def idxmax(self, axis: Axis = 0, skipna: bool = True) -> Series: """ axis = self._get_axis_number(axis) - res = self._reduce( - nanops.nanargmax, "argmax", axis=axis, skipna=skipna, numeric_only=False - ) + res = self._reduce(nanops.nanargmax, + "argmax", + axis=axis, + skipna=skipna, + numeric_only=False) indices = res._values # indices will always be np.ndarray since axis is not None and @@ -10462,9 +10551,10 @@ def _get_agg_axis(self, axis_num: int) -> Index: else: raise ValueError(f"Axis must be 0 or 1 (got {repr(axis_num)})") - def mode( - self, axis: Axis = 0, numeric_only: bool = False, dropna: bool = True - ) -> DataFrame: + def mode(self, + axis: Axis = 0, + numeric_only: bool = False, + dropna: bool = True) -> DataFrame: """ Get the mode(s) of each element along the selected axis. 
@@ -10631,9 +10721,10 @@ def quantile( if not is_list_like(q): # BlockManager.quantile expects listlike, so we wrap and unwrap here - res_df = self.quantile( - [q], axis=axis, numeric_only=numeric_only, interpolation=interpolation - ) + res_df = self.quantile([q], + axis=axis, + numeric_only=numeric_only, + interpolation=interpolation) res = res_df.iloc[0] if axis == 1 and len(self) == 0: # GH#41544 try to get an appropriate dtype @@ -10662,7 +10753,10 @@ def quantile( if is_list_like(q): res = self._constructor([], index=q, columns=cols, dtype=dtype) return res.__finalize__(self, method="quantile") - return self._constructor_sliced([], index=cols, name=q, dtype=dtype) + return self._constructor_sliced([], + index=cols, + name=q, + dtype=dtype) res = data._mgr.quantile(qs=q, axis=1, interpolation=interpolation) @@ -10755,9 +10849,10 @@ def to_timestamp( setattr(new_obj, axis_name, new_ax) return new_obj - def to_period( - self, freq: Frequency | None = None, axis: Axis = 0, copy: bool = True - ) -> DataFrame: + def to_period(self, + freq: Frequency | None = None, + axis: Axis = 0, + copy: bool = True) -> DataFrame: """ Convert DataFrame from DatetimeIndex to PeriodIndex. @@ -10885,10 +10980,8 @@ def isin(self, values) -> DataFrame: values = collections.defaultdict(list, values) result = concat( - ( - self.iloc[:, [i]].isin(values[col]) - for i, col in enumerate(self.columns) - ), + (self.iloc[:, [i]].isin(values[col]) + for i, col in enumerate(self.columns)), axis=1, ) elif isinstance(values, Series): @@ -10904,10 +10997,10 @@ def isin(self, values) -> DataFrame: raise TypeError( "only list-like or dict-like objects are allowed " "to be passed to DataFrame.isin(), " - f"you passed a '{type(values).__name__}'" - ) + f"you passed a '{type(values).__name__}'") result = self._constructor( - algorithms.isin(self.values.ravel(), values).reshape(self.shape), + algorithms.isin(self.values.ravel(), + values).reshape(self.shape), self.index, self.columns, ) @@ -10926,11 +11019,9 @@ def isin(self, values) -> DataFrame: _info_axis_name = "columns" index: Index = properties.AxisProperty( - axis=1, doc="The index (row labels) of the DataFrame." - ) + axis=1, doc="The index (row labels) of the DataFrame.") columns: Index = properties.AxisProperty( - axis=0, doc="The column labels of the DataFrame." 
- ) + axis=0, doc="The column labels of the DataFrame.") @property def _AXIS_NUMBERS(self) -> dict[str, int]: @@ -11067,9 +11158,8 @@ def bfill( ) -> DataFrame | None: return super().bfill(axis, inplace, limit, downcast) - @deprecate_nonkeyword_arguments( - version=None, allowed_args=["self", "lower", "upper"] - ) + @deprecate_nonkeyword_arguments(version=None, + allowed_args=["self", "lower", "upper"]) def clip( self: DataFrame, lower=None, @@ -11081,7 +11171,8 @@ def clip( ) -> DataFrame | None: return super().clip(lower, upper, axis, inplace, *args, **kwargs) - @deprecate_nonkeyword_arguments(version=None, allowed_args=["self", "method"]) + @deprecate_nonkeyword_arguments(version=None, + allowed_args=["self", "method"]) def interpolate( self: DataFrame, method: str = "linear", @@ -11104,9 +11195,8 @@ def interpolate( **kwargs, ) - @deprecate_nonkeyword_arguments( - version=None, allowed_args=["self", "cond", "other"] - ) + @deprecate_nonkeyword_arguments(version=None, + allowed_args=["self", "cond", "other"]) def where( self, cond, @@ -11117,11 +11207,11 @@ def where( errors="raise", try_cast=lib.no_default, ): - return super().where(cond, other, inplace, axis, level, errors, try_cast) + return super().where(cond, other, inplace, axis, level, errors, + try_cast) - @deprecate_nonkeyword_arguments( - version=None, allowed_args=["self", "cond", "other"] - ) + @deprecate_nonkeyword_arguments(version=None, + allowed_args=["self", "cond", "other"]) def mask( self, cond, @@ -11132,7 +11222,8 @@ def mask( errors="raise", try_cast=lib.no_default, ): - return super().mask(cond, other, inplace, axis, level, errors, try_cast) + return super().mask(cond, other, inplace, axis, level, errors, + try_cast) DataFrame._add_numeric_operations() @@ -11164,6 +11255,5 @@ def _reindex_for_setitem(value: DataFrame | Series, index: Index) -> ArrayLike: raise err raise TypeError( - "incompatible index of inserted column with frame index" - ) from err + "incompatible index of inserted column with frame index") from err return reindexed_value diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index eeeeb41d38d6a..bd97030ec81b0 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -184,7 +184,6 @@ ) from pandas.core.arrays import PeriodArray - __all__ = ["Index"] _unsortable_types = frozenset(("mixed", "mixed-integer")) @@ -200,7 +199,6 @@ _index_shared_docs: dict[str, str] = {} str_t = str - _dtype_obj = np.dtype("object") @@ -218,7 +216,11 @@ def join( return_indexers: bool = False, sort: bool = False, ): - join_index, lidx, ridx = meth(self, other, how=how, level=level, sort=sort) + join_index, lidx, ridx = meth(self, + other, + how=how, + level=level, + sort=sort) if not return_indexers: return join_index @@ -310,11 +312,9 @@ class Index(IndexOpsMixin, PandasObject): """ # tolist is not actually deprecated, just suppressed in the __dir__ - _hidden_attrs: frozenset[str] = ( - PandasObject._hidden_attrs - | IndexOpsMixin._hidden_attrs - | frozenset(["contains", "set_value"]) - ) + _hidden_attrs: frozenset[str] = (PandasObject._hidden_attrs + | IndexOpsMixin._hidden_attrs + | frozenset(["contains", "set_value"])) # To hand over control to subclasses _join_precedence = 1 @@ -325,7 +325,8 @@ class Index(IndexOpsMixin, PandasObject): # given the dtypes of the passed arguments @final - def _left_indexer_unique(self: _IndexT, other: _IndexT) -> npt.NDArray[np.intp]: + def _left_indexer_unique(self: _IndexT, + other: _IndexT) -> npt.NDArray[np.intp]: # Caller is responsible for 
ensuring other.dtype == self.dtype sv = self._get_engine_target() ov = other._get_engine_target() @@ -378,9 +379,10 @@ def _outer_indexer( _typ: str = "index" _data: ExtensionArray | np.ndarray - _data_cls: type[ExtensionArray] | tuple[type[np.ndarray], type[ExtensionArray]] = ( - np.ndarray, - ExtensionArray, + _data_cls: type[ExtensionArray] | tuple[type[np.ndarray], + type[ExtensionArray]] = ( + np.ndarray, + ExtensionArray, ) _id: object | None = None _name: Hashable = None @@ -399,8 +401,7 @@ def _outer_indexer( _is_backward_compat_public_numeric_index: bool = False _engine_type: type[libindex.IndexEngine] | type[ - libindex.ExtensionEngine - ] = libindex.ObjectEngine + libindex.ExtensionEngine] = libindex.ObjectEngine # whether we support partial string indexing. Overridden # in DatetimeIndex and PeriodIndex _supports_partial_string_indexing = False @@ -412,9 +413,13 @@ def _outer_indexer( # -------------------------------------------------------------------- # Constructors - def __new__( - cls, data=None, dtype=None, copy=False, name=None, tupleize_cols=True, **kwargs - ) -> Index: + def __new__(cls, + data=None, + dtype=None, + copy=False, + name=None, + tupleize_cols=True, + **kwargs) -> Index: if kwargs: warnings.warn( @@ -481,11 +486,9 @@ def __new__( return Index._simple_new(data, name=name) # index-like - elif ( - isinstance(data, Index) - and data._is_backward_compat_public_numeric_index - and dtype is None - ): + elif (isinstance(data, Index) + and data._is_backward_compat_public_numeric_index + and dtype is None): return data._constructor(data, name=name, copy=copy) elif isinstance(data, (np.ndarray, Index, ABCSeries)): @@ -514,8 +517,7 @@ def __new__( if dtype is None: arr = _maybe_cast_data_without_dtype( - arr, cast_numeric_deprecated=True - ) + arr, cast_numeric_deprecated=True) dtype = arr.dtype if kwargs: @@ -529,7 +531,11 @@ def __new__( elif is_scalar(data): raise cls._scalar_data_error(data) elif hasattr(data, "__array__"): - return Index(np.asarray(data), dtype=dtype, copy=copy, name=name, **kwargs) + return Index(np.asarray(data), + dtype=dtype, + copy=copy, + name=name, + **kwargs) else: if tupleize_cols and is_list_like(data): @@ -542,9 +548,9 @@ def __new__( # 10697 from pandas.core.indexes.multi import MultiIndex - return MultiIndex.from_tuples( - data, names=name or kwargs.get("names") - ) + return MultiIndex.from_tuples(data, + names=name + or kwargs.get("names")) # other iterable of some kind subarr = com.asarray_tuplesafe(data, dtype=_dtype_obj) @@ -554,8 +560,8 @@ def __new__( # "Union[ExtensionArray, ndarray[Any, Any]]", variable has type # "ndarray[Any, Any]") subarr = _maybe_cast_data_without_dtype( # type: ignore[assignment] - subarr, cast_numeric_deprecated=False - ) + subarr, + cast_numeric_deprecated=False) dtype = subarr.dtype return Index(subarr, dtype=dtype, copy=copy, name=name, **kwargs) @@ -637,8 +643,8 @@ def _dtype_to_subclass(cls, dtype: DtypeObj): return Index elif issubclass( - dtype.type, (str, bool, np.bool_, complex, np.complex64, np.complex128) - ): + dtype.type, + (str, bool, np.bool_, complex, np.complex64, np.complex128)): return Index raise NotImplementedError(dtype) @@ -676,7 +682,9 @@ def asi8(self): return None @classmethod - def _simple_new(cls: type[_IndexT], values, name: Hashable = None) -> _IndexT: + def _simple_new(cls: type[_IndexT], + values, + name: Hashable = None) -> _IndexT: """ We require that we have a dtype compat for the values. If we are passed a non-dtype compat, then coerce using the constructor. 
@@ -700,14 +708,16 @@ def _with_infer(cls, *args, **kwargs): for ndarray[object] inputs. """ with warnings.catch_warnings(): - warnings.filterwarnings("ignore", ".*the Index constructor", FutureWarning) + warnings.filterwarnings("ignore", ".*the Index constructor", + FutureWarning) result = cls(*args, **kwargs) if result.dtype == _dtype_obj and not result._is_multi: # error: Argument 1 to "maybe_convert_objects" has incompatible type # "Union[ExtensionArray, ndarray[Any, Any]]"; expected # "ndarray[Any, Any]" - values = lib.maybe_convert_objects(result._values) # type: ignore[arg-type] + values = lib.maybe_convert_objects( + result._values) # type: ignore[arg-type] if values.dtype.kind in ["i", "u", "f", "b"]: return Index(values, name=result.name) @@ -764,7 +774,8 @@ def _format_duplicate_message(self) -> DataFrame: if self._is_multi: # test_format_duplicate_labels_message_multi # error: "Type[Index]" has no attribute "from_tuples" [attr-defined] - out.index = type(self).from_tuples(out.index) # type: ignore[attr-defined] + out.index = type(self).from_tuples( + out.index) # type: ignore[attr-defined] if self.nlevels == 1: out = out.rename_axis("label") @@ -789,7 +800,9 @@ def _get_attributes_dict(self) -> dict[str_t, Any]: ) return {k: getattr(self, k, None) for k in self._attributes} - def _shallow_copy(self: _IndexT, values, name: Hashable = no_default) -> _IndexT: + def _shallow_copy(self: _IndexT, + values, + name: Hashable = no_default) -> _IndexT: """ Create a new Index with the same class as the caller, don't copy the data, use the same object attributes with passed in attributes taking @@ -867,15 +880,11 @@ def _cleanup(self) -> None: self._engine.clear_mapping() @cache_readonly - def _engine( - self, - ) -> libindex.IndexEngine | libindex.ExtensionEngine: + def _engine(self, ) -> libindex.IndexEngine | libindex.ExtensionEngine: # For base class (object dtype) we get ObjectEngine target_values = self._get_engine_target() - if ( - isinstance(target_values, ExtensionArray) - and self._engine_type is libindex.ObjectEngine - ): + if (isinstance(target_values, ExtensionArray) + and self._engine_type is libindex.ObjectEngine): return libindex.ExtensionEngine(target_values) target_values = cast(np.ndarray, target_values) @@ -902,7 +911,7 @@ def _dir_additions_for_owner(self) -> set[str_t]: """ return { c - for c in self.unique(level=0)[: get_option("display.max_dir_items")] + for c in self.unique(level=0)[:get_option("display.max_dir_items")] if isinstance(c, str) and c.isidentifier() } @@ -922,26 +931,26 @@ def __array__(self, dtype=None) -> np.ndarray: """ return np.asarray(self._data, dtype=dtype) - def __array_ufunc__(self, ufunc: np.ufunc, method: str_t, *inputs, **kwargs): - if any(isinstance(other, (ABCSeries, ABCDataFrame)) for other in inputs): + def __array_ufunc__(self, ufunc: np.ufunc, method: str_t, *inputs, + **kwargs): + if any( + isinstance(other, (ABCSeries, ABCDataFrame)) + for other in inputs): return NotImplemented result = arraylike.maybe_dispatch_ufunc_to_dunder_op( - self, ufunc, method, *inputs, **kwargs - ) + self, ufunc, method, *inputs, **kwargs) if result is not NotImplemented: return result if "out" in kwargs: # e.g. 
test_dti_isub_tdi - return arraylike.dispatch_ufunc_with_out( - self, ufunc, method, *inputs, **kwargs - ) + return arraylike.dispatch_ufunc_with_out(self, ufunc, method, + *inputs, **kwargs) if method == "reduce": result = arraylike.dispatch_reduction_ufunc( - self, ufunc, method, *inputs, **kwargs - ) + self, ufunc, method, *inputs, **kwargs) if result is not NotImplemented: return result @@ -958,7 +967,8 @@ def __array_wrap__(self, result, context=None): Gets called after a ufunc and other functions e.g. np.split. """ result = lib.item_from_zerodim(result) - if is_bool_dtype(result) or lib.is_scalar(result) or np.ndim(result) > 1: + if is_bool_dtype(result) or lib.is_scalar( + result) or np.ndim(result) > 1: return result return Index(result, name=self.name) @@ -1009,9 +1019,9 @@ def view(self, cls=None): if isinstance(cls, str): dtype = pandas_dtype(cls) - if isinstance(dtype, (np.dtype, ExtensionDtype)) and needs_i8_conversion( - dtype - ): + if isinstance( + dtype, + (np.dtype, ExtensionDtype)) and needs_i8_conversion(dtype): if dtype.kind == "m" and dtype != "m8[ns]": # e.g. m8[s] return self._data.view(cls) @@ -1062,12 +1072,8 @@ def astype(self, dtype, copy: bool = True): # Ensure that self.astype(self.dtype) is self return self.copy() if copy else self - if ( - self.dtype == np.dtype("M8[ns]") - and isinstance(dtype, np.dtype) - and dtype.kind == "M" - and dtype != np.dtype("M8[ns]") - ): + if (self.dtype == np.dtype("M8[ns]") and isinstance(dtype, np.dtype) + and dtype.kind == "M" and dtype != np.dtype("M8[ns]")): # For now DatetimeArray supports this by unwrapping ndarray, # but DatetimeIndex doesn't raise TypeError(f"Cannot cast {type(self).__name__} to dtype") @@ -1082,8 +1088,7 @@ def astype(self, dtype, copy: bool = True): # TODO: this differs from Series behavior; can/should we align them? raise TypeError( f"Cannot convert Float64Index to dtype {dtype}; integer " - "values are required for conversion" - ) + "values are required for conversion") elif isinstance(dtype, ExtensionDtype): cls = dtype.construct_array_type() @@ -1115,14 +1120,16 @@ def astype(self, dtype, copy: bool = True): # NumericIndex[int32] and not Int64Index with dtype int64. # When Int64Index etc. are removed from the code base, removed this also. if isinstance(dtype, np.dtype) and is_numeric_dtype(dtype): - return self._constructor( - new_values, name=self.name, dtype=dtype, copy=False - ) - return Index(new_values, name=self.name, dtype=new_values.dtype, copy=False) - - _index_shared_docs[ - "take" - ] = """ + return self._constructor(new_values, + name=self.name, + dtype=dtype, + copy=False) + return Index(new_values, + name=self.name, + dtype=new_values.dtype, + copy=False) + + _index_shared_docs["take"] = """ Return a new %(klass)s of the values selected by the indices. For internal compatibility with numpy arrays. 
@@ -1151,9 +1158,12 @@ def astype(self, dtype, copy: bool = True): """ @Appender(_index_shared_docs["take"] % _index_doc_kwargs) - def take( - self, indices, axis: int = 0, allow_fill: bool = True, fill_value=None, **kwargs - ): + def take(self, + indices, + axis: int = 0, + allow_fill: bool = True, + fill_value=None, + **kwargs): if kwargs: nv.validate_take((), kwargs) if is_scalar(indices): @@ -1165,19 +1175,21 @@ def take( # in the case where allow_fill is True and fill_value is not None values = self._values if isinstance(values, np.ndarray): - taken = algos.take( - values, indices, allow_fill=allow_fill, fill_value=self._na_value - ) + taken = algos.take(values, + indices, + allow_fill=allow_fill, + fill_value=self._na_value) else: # algos.take passes 'axis' keyword which not all EAs accept - taken = values.take( - indices, allow_fill=allow_fill, fill_value=self._na_value - ) + taken = values.take(indices, + allow_fill=allow_fill, + fill_value=self._na_value) # _constructor so RangeIndex->Int64Index return self._constructor._simple_new(taken, name=self.name) @final - def _maybe_disallow_fill(self, allow_fill: bool, fill_value, indices) -> bool: + def _maybe_disallow_fill(self, allow_fill: bool, fill_value, + indices) -> bool: """ We only use pandas-style take when allow_fill is True _and_ fill_value is not None. @@ -1188,8 +1200,7 @@ def _maybe_disallow_fill(self, allow_fill: bool, fill_value, indices) -> bool: if (indices < -1).any(): raise ValueError( "When allow_fill=True and fill_value is not None, " - "all indices must be >= -1" - ) + "all indices must be >= -1") else: cls_name = type(self).__name__ raise ValueError( @@ -1199,9 +1210,7 @@ def _maybe_disallow_fill(self, allow_fill: bool, fill_value, indices) -> bool: allow_fill = False return allow_fill - _index_shared_docs[ - "repeat" - ] = """ + _index_shared_docs["repeat"] = """ Repeat elements of a %(klass)s. 
Returns a new %(klass)s where each element of the current %(klass)s @@ -1413,7 +1422,8 @@ def _get_level_names(self) -> Hashable | Sequence[Hashable]: """ if self._is_multi: return [ - level if name is None else name for level, name in enumerate(self.names) + level if name is None else name + for level, name in enumerate(self.names) ] else: return 0 if self.name is None else self.name @@ -1437,17 +1447,16 @@ def format( header = [] if name: header.append( - pprint_thing(self.name, escape_chars=("\t", "\r", "\n")) - if self.name is not None - else "" - ) + pprint_thing(self.name, escape_chars=( + "\t", "\r", "\n")) if self.name is not None else "") if formatter is not None: return header + list(self.map(formatter)) return self._format_with_header(header, na_rep=na_rep) - def _format_with_header(self, header: list[str_t], na_rep: str_t) -> list[str_t]: + def _format_with_header(self, header: list[str_t], + na_rep: str_t) -> list[str_t]: from pandas.io.formats.format import format_array values = self._values @@ -1456,7 +1465,10 @@ def _format_with_header(self, header: list[str_t], na_rep: str_t) -> list[str_t] values = cast(np.ndarray, values) values = lib.maybe_convert_objects(values, safe=True) - result = [pprint_thing(x, escape_chars=("\t", "\r", "\n")) for x in values] + result = [ + pprint_thing(x, escape_chars=("\t", "\r", "\n")) + for x in values + ] # could have nans mask = is_float_nan(values) @@ -1639,9 +1651,9 @@ def to_series(self, index=None, name: Hashable = None) -> Series: return Series(self._values.copy(), index=index, name=name) - def to_frame( - self, index: bool = True, name: Hashable = lib.no_default - ) -> DataFrame: + def to_frame(self, + index: bool = True, + name: Hashable = lib.no_default) -> DataFrame: """ Create a DataFrame with a column containing the Index. @@ -1727,15 +1739,15 @@ def name(self, value: Hashable): # Used in MultiIndex.levels to avoid silently ignoring name updates. raise RuntimeError( "Cannot set name on a level of a MultiIndex. Use " - "'MultiIndex.set_names' instead." - ) + "'MultiIndex.set_names' instead.") maybe_extract_name(value, None, type(self)) self._name = value @final - def _validate_names( - self, name=None, names=None, deep: bool = False - ) -> list[Hashable]: + def _validate_names(self, + name=None, + names=None, + deep: bool = False) -> list[Hashable]: """ Handles the quirks of having a singular 'name' parameter for general Index and plural 'names' parameter for MultiIndex. 
@@ -1761,12 +1773,13 @@ def _validate_names( ) # All items in 'new_names' need to be hashable - validate_all_hashable(*new_names, error_name=f"{type(self).__name__}.name") + validate_all_hashable(*new_names, + error_name=f"{type(self).__name__}.name") return new_names def _get_names(self) -> FrozenList: - return FrozenList((self.name,)) + return FrozenList((self.name, )) def _set_names(self, values, *, level=None) -> None: """ @@ -1787,17 +1800,20 @@ def _set_names(self, values, *, level=None) -> None: if not is_list_like(values): raise ValueError("Names must be a list-like") if len(values) != 1: - raise ValueError(f"Length of new names must be 1, got {len(values)}") + raise ValueError( + f"Length of new names must be 1, got {len(values)}") # GH 20527 # All items in 'name' need to be hashable: - validate_all_hashable(*values, error_name=f"{type(self).__name__}.name") + validate_all_hashable(*values, + error_name=f"{type(self).__name__}.name") self._name = values[0] names = property(fset=_set_names, fget=_get_names) - @deprecate_nonkeyword_arguments(version=None, allowed_args=["self", "names"]) + @deprecate_nonkeyword_arguments(version=None, + allowed_args=["self", "names"]) def set_names(self, names, level=None, inplace: bool = False): """ Set Index or MultiIndex name. @@ -1873,19 +1889,23 @@ def set_names(self, names, level=None, inplace: bool = False): if level is not None and not isinstance(self, ABCMultiIndex): raise ValueError("Level must be None for non-MultiIndex") - elif level is not None and not is_list_like(level) and is_list_like(names): - raise TypeError("Names must be a string when a single level is provided.") + elif level is not None and not is_list_like(level) and is_list_like( + names): + raise TypeError( + "Names must be a string when a single level is provided.") elif not is_list_like(names) and level is None and self.nlevels > 1: raise TypeError("Must pass list-like as `names`.") elif is_dict_like(names) and not isinstance(self, ABCMultiIndex): - raise TypeError("Can only pass dict-like as `names` for MultiIndex.") + raise TypeError( + "Can only pass dict-like as `names` for MultiIndex.") elif is_dict_like(names) and level is not None: raise TypeError("Can not pass level for dictlike `names`.") - if isinstance(self, ABCMultiIndex) and is_dict_like(names) and level is None: + if isinstance(self, + ABCMultiIndex) and is_dict_like(names) and level is None: # Transform dict to list of new names and corresponding levels level, names_adjusted = [], [] for i, name in enumerate(self.names): @@ -1986,10 +2006,8 @@ def _validate_index_level(self, level) -> None: """ if isinstance(level, int): if level < 0 and level != -1: - raise IndexError( - "Too many levels: Index has only 1 level, " - f"{level} is not a valid level number" - ) + raise IndexError("Too many levels: Index has only 1 level, " + f"{level} is not a valid level number") elif level > 0: raise IndexError( f"Too many levels: Index has only 1 level, not {level + 1}" @@ -2021,14 +2039,13 @@ def sortlevel(self, level=None, ascending=True, sort_remaining=None): Index """ if not isinstance(ascending, (list, bool)): - raise TypeError( - "ascending must be a single bool value or" - "a list of bool values of length 1" - ) + raise TypeError("ascending must be a single bool value or" + "a list of bool values of length 1") if isinstance(ascending, list): if len(ascending) != 1: - raise TypeError("ascending must be a list of bool values of length 1") + raise TypeError( + "ascending must be a list of bool values of length 1") 
ascending = ascending[0] if not isinstance(ascending, bool): @@ -2140,8 +2157,7 @@ def _drop_level_numbers(self, levnums: list[int]): if len(levnums) >= self.nlevels: raise ValueError( f"Cannot remove {len(levnums)} levels from an index with " - f"{self.nlevels} levels: at least one level must be left." - ) + f"{self.nlevels} levels: at least one level must be left.") # The two checks above guarantee that here self is a MultiIndex self = cast("MultiIndex", self) @@ -2164,9 +2180,12 @@ def _drop_level_numbers(self, levnums: list[int]): # see test_reset_index_empty_rangeindex result = lev[:0] else: - res_values = algos.take(lev._values, new_codes[0], allow_fill=True) + res_values = algos.take(lev._values, + new_codes[0], + allow_fill=True) # _constructor instead of type(lev) for RangeIndex compat GH#35230 - result = lev._constructor._simple_new(res_values, name=new_names[0]) + result = lev._constructor._simple_new(res_values, + name=new_names[0]) else: # set nan if needed mask = new_codes[0] == -1 @@ -2188,7 +2207,10 @@ def _drop_level_numbers(self, levnums: list[int]): ) def _get_grouper_for_level( - self, mapper, *, level=None + self, + mapper, + *, + level=None ) -> tuple[Index, npt.NDArray[np.signedinteger] | None, Index | None]: """ Get index grouper corresponding to an index level @@ -2472,7 +2494,9 @@ def is_floating(self) -> bool: >>> idx.is_floating() False """ - return self.inferred_type in ["floating", "mixed-integer-float", "integer-na"] + return self.inferred_type in [ + "floating", "mixed-integer-float", "integer-na" + ] @final def is_numeric(self) -> bool: @@ -2911,8 +2935,7 @@ def fillna(self, value=None, downcast=None): return Index._with_infer(result, name=self.name) raise NotImplementedError( f"{type(self).__name__}.fillna does not support 'downcast' " - "argument values other than 'None'." - ) + "argument values other than 'None'.") return self._view() def dropna(self: _IndexT, how: str_t = "any") -> _IndexT: @@ -2971,7 +2994,8 @@ def unique(self: _IndexT, level: Hashable | None = None) -> _IndexT: return self._shallow_copy(result) @deprecate_nonkeyword_arguments(version=None, allowed_args=["self"]) - def drop_duplicates(self: _IndexT, keep: str_t | bool = "first") -> _IndexT: + def drop_duplicates(self: _IndexT, + keep: str_t | bool = "first") -> _IndexT: """ Return Index with duplicate values removed. @@ -3023,8 +3047,9 @@ def drop_duplicates(self: _IndexT, keep: str_t | bool = "first") -> _IndexT: return super().drop_duplicates(keep=keep) def duplicated( - self, keep: Literal["first", "last", False] = "first" - ) -> npt.NDArray[np.bool_]: + self, + keep: Literal["first", "last", + False] = "first") -> npt.NDArray[np.bool_]: """ Indicate duplicate index values. @@ -3127,8 +3152,7 @@ def __xor__(self, other): def __nonzero__(self): raise ValueError( f"The truth value of a {type(self).__name__} is ambiguous. " - "Use a.empty, a.bool(), a.item(), a.any() or a.all()." - ) + "Use a.empty, a.bool(), a.item(), a.any() or a.all().") __bool__ = __nonzero__ @@ -3149,10 +3173,8 @@ def _get_reconciled_name_object(self, other): @final def _validate_sort_keyword(self, sort): if sort not in [None, False]: - raise ValueError( - "The 'sort' keyword only takes the values of " - f"None or False; {sort} was passed." 
- ) + raise ValueError("The 'sort' keyword only takes the values of " + f"None or False; {sort} was passed.") @final def _deprecate_dti_setop(self, other: Index, setop: str_t): @@ -3162,12 +3184,9 @@ def _deprecate_dti_setop(self, other: Index, setop: str_t): """ # Caller is responsibelf or checking # `not is_dtype_equal(self.dtype, other.dtype)` - if ( - isinstance(self, ABCDatetimeIndex) - and isinstance(other, ABCDatetimeIndex) - and self.tz is not None - and other.tz is not None - ): + if (isinstance(self, ABCDatetimeIndex) + and isinstance(other, ABCDatetimeIndex) and self.tz is not None + and other.tz is not None): # GH#39328, GH#45357 warnings.warn( f"In a future version, the {setop} of DatetimeIndex objects " @@ -3267,15 +3286,12 @@ def union(self, other, sort=None): other, result_name = self._convert_can_do_setop(other) if not is_dtype_equal(self.dtype, other.dtype): - if ( - isinstance(self, ABCMultiIndex) - and not is_object_dtype(unpack_nested_dtype(other)) - and len(other) > 0 - ): + if (isinstance(self, ABCMultiIndex) + and not is_object_dtype(unpack_nested_dtype(other)) + and len(other) > 0): raise NotImplementedError( "Can only union MultiIndex with MultiIndex or Index of tuples, " - "try mi.to_flat_index().union(other) instead." - ) + "try mi.to_flat_index().union(other) instead.") self._deprecate_dti_setop(other, "union") dtype = self._find_common_type_compat(other) @@ -3317,13 +3333,10 @@ def _union(self, other: Index, sort): lvals = self._values rvals = other._values - if ( - sort is None - and self.is_monotonic_increasing - and other.is_monotonic_increasing - and not (self.has_duplicates and other.has_duplicates) - and self._can_use_libjoin - ): + if (sort is None and self.is_monotonic_increasing + and other.is_monotonic_increasing + and not (self.has_duplicates and other.has_duplicates) + and self._can_use_libjoin): # Both are monotonic and at least one is unique, so can use outer join # (actually don't need either unique, but without this restriction # test_union_same_value_duplicated_in_both fails) @@ -3457,11 +3470,8 @@ def _intersection(self, other: Index, sort=False): """ intersection specialized to the case with matching dtypes. 
""" - if ( - self.is_monotonic_increasing - and other.is_monotonic_increasing - and self._can_use_libjoin - ): + if (self.is_monotonic_increasing and other.is_monotonic_increasing + and self._can_use_libjoin): try: result = self._inner_indexer(other)[0] except TypeError: @@ -3570,7 +3580,9 @@ def _difference(self, other, sort): indexer = this.get_indexer_for(other) indexer = indexer.take((indexer != -1).nonzero()[0]) - label_diff = np.setdiff1d(np.arange(this.size), indexer, assume_unique=True) + label_diff = np.setdiff1d(np.arange(this.size), + indexer, + assume_unique=True) the_diff = this._values.take(label_diff) the_diff = _maybe_try_sort(the_diff, sort) @@ -3631,7 +3643,8 @@ def symmetric_difference(self, other, result_name=None, sort=None): dtype = self._find_common_type_compat(other) this = self.astype(dtype, copy=False) that = other.astype(dtype, copy=False) - return this.symmetric_difference(that, sort=sort).rename(result_name) + return this.symmetric_difference(that, + sort=sort).rename(result_name) this = self.unique() other = other.unique() @@ -3639,9 +3652,9 @@ def symmetric_difference(self, other, result_name=None, sort=None): # {this} minus {other} common_indexer = indexer.take((indexer != -1).nonzero()[0]) - left_indexer = np.setdiff1d( - np.arange(this.size), common_indexer, assume_unique=True - ) + left_indexer = np.setdiff1d(np.arange(this.size), + common_indexer, + assume_unique=True) left_diff = this._values.take(left_indexer) # {other} minus {this} @@ -3729,10 +3742,8 @@ def get_loc(self, key, method=None, tolerance=None): """ if method is None: if tolerance is not None: - raise ValueError( - "tolerance argument only valid if using pad, " - "backfill or nearest lookups" - ) + raise ValueError("tolerance argument only valid if using pad, " + "backfill or nearest lookups") casted_key = self._maybe_cast_indexer(key) try: return self._engine.get_loc(casted_key) @@ -3768,9 +3779,7 @@ def get_loc(self, key, method=None, tolerance=None): raise KeyError(key) return loc - _index_shared_docs[ - "get_indexer" - ] = """ + _index_shared_docs["get_indexer"] = """ Compute indexer and mask for new index given the current index. The indexer should be then used as an input to ndarray.take to align the current data to the new index. @@ -3841,10 +3850,13 @@ def get_indexer( if len(target) == 0: return np.array([], dtype=np.intp) - if not self._should_compare(target) and not self._should_partial_index(target): + if not self._should_compare(target) and not self._should_partial_index( + target): # IntervalIndex get special treatment bc numeric scalars can be # matched to Interval scalars - return self._get_indexer_non_comparable(target, method=method, unique=True) + return self._get_indexer_non_comparable(target, + method=method, + unique=True) if is_categorical_dtype(self.dtype): # _maybe_cast_listlike_indexer ensures target has our dtype @@ -3871,7 +3883,9 @@ def get_indexer( # e.g. 
test_append_different_columns_types categories_indexer = self.get_indexer(target.categories) - indexer = algos.take_nd(categories_indexer, target.codes, fill_value=-1) + indexer = algos.take_nd(categories_indexer, + target.codes, + fill_value=-1) if (not self._is_multi and self.hasnans) and target.hasnans: # Exclude MultiIndex because hasnans raises NotImplementedError @@ -3885,25 +3899,27 @@ def get_indexer( pself, ptarget = self._maybe_promote(target) if pself is not self or ptarget is not target: - return pself.get_indexer( - ptarget, method=method, limit=limit, tolerance=tolerance - ) + return pself.get_indexer(ptarget, + method=method, + limit=limit, + tolerance=tolerance) if is_dtype_equal(self.dtype, target.dtype) and self.equals(target): # Only call equals if we have same dtype to avoid inference/casting return np.arange(len(target), dtype=np.intp) - if not is_dtype_equal(self.dtype, target.dtype) and not is_interval_dtype( - self.dtype - ): + if not is_dtype_equal(self.dtype, + target.dtype) and not is_interval_dtype( + self.dtype): # IntervalIndex gets special treatment for partial-indexing dtype = self._find_common_type_compat(target) this = self.astype(dtype, copy=False) target = target.astype(dtype, copy=False) - return this._get_indexer( - target, method=method, limit=limit, tolerance=tolerance - ) + return this._get_indexer(target, + method=method, + limit=limit, + tolerance=tolerance) return self._get_indexer(target, method, limit, tolerance) @@ -3927,8 +3943,7 @@ def _get_indexer( # error: Item "IndexEngine" of "Union[IndexEngine, ExtensionEngine]" # has no attribute "_extract_level_codes" tgt_values = engine._extract_level_codes( # type: ignore[union-attr] - target - ) + target) else: tgt_values = target._get_engine_target() @@ -3943,7 +3958,8 @@ def _should_partial_index(self, target: Index) -> bool: """ if is_interval_dtype(self.dtype): # "Index" has no attribute "left" - return self.left._should_compare(target) # type: ignore[attr-defined] + return self.left._should_compare( + target) # type: ignore[attr-defined] return False @final @@ -3956,7 +3972,9 @@ def _check_indexing_method( """ Raise if we have a get_indexer `method` that is not supported or valid. 
""" - if method not in [None, "bfill", "backfill", "pad", "ffill", "nearest"]: + if method not in [ + None, "bfill", "backfill", "pad", "ffill", "nearest" + ]: # in practice the clean_reindex_fill_method call would raise # before we get here raise ValueError("Invalid fill method") # pragma: no cover @@ -3965,13 +3983,11 @@ def _check_indexing_method( if method == "nearest": raise NotImplementedError( "method='nearest' not implemented yet " - "for MultiIndex; see GitHub issue 9365" - ) + "for MultiIndex; see GitHub issue 9365") elif method in ("pad", "backfill"): if tolerance is not None: raise NotImplementedError( - "tolerance not implemented yet for MultiIndex" - ) + "tolerance not implemented yet for MultiIndex") if is_interval_dtype(self.dtype) or is_categorical_dtype(self.dtype): # GH#37871 for now this is only for IntervalIndex and CategoricalIndex @@ -3982,27 +3998,27 @@ def _check_indexing_method( if method is None: if tolerance is not None: - raise ValueError( - "tolerance argument only valid if doing pad, " - "backfill or nearest reindexing" - ) + raise ValueError("tolerance argument only valid if doing pad, " + "backfill or nearest reindexing") if limit is not None: - raise ValueError( - "limit argument only valid if doing pad, " - "backfill or nearest reindexing" - ) + raise ValueError("limit argument only valid if doing pad, " + "backfill or nearest reindexing") - def _convert_tolerance(self, tolerance, target: np.ndarray | Index) -> np.ndarray: + def _convert_tolerance(self, tolerance, + target: np.ndarray | Index) -> np.ndarray: # override this method on subclasses tolerance = np.asarray(tolerance) if target.size != tolerance.size and tolerance.size > 1: - raise ValueError("list-like tolerance size must match target index size") + raise ValueError( + "list-like tolerance size must match target index size") return tolerance @final - def _get_fill_indexer( - self, target: Index, method: str_t, limit: int | None = None, tolerance=None - ) -> npt.NDArray[np.intp]: + def _get_fill_indexer(self, + target: Index, + method: str_t, + limit: int | None = None, + tolerance=None) -> npt.NDArray[np.intp]: if self._is_multi: # TODO: get_indexer_with_fill docstring says values must be _sorted_ @@ -4010,32 +4026,39 @@ def _get_fill_indexer( # error: "IndexEngine" has no attribute "get_indexer_with_fill" engine = self._engine return engine.get_indexer_with_fill( # type: ignore[union-attr] - target=target._values, values=self._values, method=method, limit=limit - ) + target=target._values, + values=self._values, + method=method, + limit=limit) if self.is_monotonic_increasing and target.is_monotonic_increasing: target_values = target._get_engine_target() own_values = self._get_engine_target() if not isinstance(target_values, np.ndarray) or not isinstance( - own_values, np.ndarray - ): + own_values, np.ndarray): raise NotImplementedError if method == "pad": indexer = libalgos.pad(own_values, target_values, limit=limit) else: # i.e. 
"backfill" - indexer = libalgos.backfill(own_values, target_values, limit=limit) + indexer = libalgos.backfill(own_values, + target_values, + limit=limit) else: - indexer = self._get_fill_indexer_searchsorted(target, method, limit) + indexer = self._get_fill_indexer_searchsorted( + target, method, limit) if tolerance is not None and len(self): - indexer = self._filter_indexer_tolerance(target, indexer, tolerance) + indexer = self._filter_indexer_tolerance(target, indexer, + tolerance) return indexer @final def _get_fill_indexer_searchsorted( - self, target: Index, method: str_t, limit: int | None = None - ) -> npt.NDArray[np.intp]: + self, + target: Index, + method: str_t, + limit: int | None = None) -> npt.NDArray[np.intp]: """ Fallback pad/backfill get_indexer that works for monotonic decreasing indexes and non-monotonic targets. @@ -4043,15 +4066,15 @@ def _get_fill_indexer_searchsorted( if limit is not None: raise ValueError( f"limit argument for {repr(method)} method only well-defined " - "if index and target are monotonic" - ) + "if index and target are monotonic") side: Literal["left", "right"] = "left" if method == "pad" else "right" # find exact matches first (this simplifies the algorithm) indexer = self.get_indexer(target) nonexact = indexer == -1 - indexer[nonexact] = self._searchsorted_monotonic(target[nonexact], side) + indexer[nonexact] = self._searchsorted_monotonic( + target[nonexact], side) if side == "left": # searchsorted returns "indices into a sorted array such that, # if the corresponding elements in v were inserted before the @@ -4067,9 +4090,8 @@ def _get_fill_indexer_searchsorted( return indexer @final - def _get_nearest_indexer( - self, target: Index, limit: int | None, tolerance - ) -> npt.NDArray[np.intp]: + def _get_nearest_indexer(self, target: Index, limit: int | None, + tolerance) -> npt.NDArray[np.intp]: """ Get the indexer for the nearest index labels; requires an index with values that can be subtracted from each other (e.g., not strings or @@ -4091,7 +4113,8 @@ def _get_nearest_indexer( right_indexer, ) if tolerance is not None: - indexer = self._filter_indexer_tolerance(target, indexer, tolerance) + indexer = self._filter_indexer_tolerance(target, indexer, + tolerance) return indexer @final @@ -4107,9 +4130,8 @@ def _filter_indexer_tolerance( return np.where(distance <= tolerance, indexer, -1) @final - def _difference_compat( - self, target: Index, indexer: npt.NDArray[np.intp] - ) -> ArrayLike: + def _difference_compat(self, target: Index, + indexer: npt.NDArray[np.intp]) -> ArrayLike: # Compatibility for PeriodArray, for which __sub__ returns an ndarray[object] # of DateOffset objects, which do not support __abs__ (and would be slow # if they did) @@ -4121,7 +4143,8 @@ def _difference_compat( diff = own_values[indexer] - target_values else: # error: Unsupported left operand type for - ("ExtensionArray") - diff = self._values[indexer] - target._values # type: ignore[operator] + diff = self._values[ + indexer] - target._values # type: ignore[operator] return abs(diff) # -------------------------------------------------------------------- @@ -4137,7 +4160,10 @@ def _validate_positional_slice(self, key: slice) -> None: self._validate_indexer("positional", key.stop, "iloc") self._validate_indexer("positional", key.step, "iloc") - def _convert_slice_indexer(self, key: slice, kind: str_t, is_frame: bool = False): + def _convert_slice_indexer(self, + key: slice, + kind: str_t, + is_frame: bool = False): """ Convert a slice indexer. 
@@ -4162,9 +4188,8 @@ def is_int(v): return v is None or is_integer(v) is_index_slice = is_int(start) and is_int(stop) and is_int(step) - is_positional = is_index_slice and not ( - self.is_integer() or self.is_categorical() - ) + is_positional = is_index_slice and not (self.is_integer() + or self.is_categorical()) if kind == "getitem": """ @@ -4179,18 +4204,14 @@ def is_int(v): # label-based vs positional is irrelevant pass elif isinstance(self, ABCRangeIndex) and self._range == range( - len(self) - ): + len(self)): # In this case there is no difference between label-based # and positional, so nothing will change. pass - elif ( - self.dtype.kind in ["i", "u"] - and self._is_strictly_monotonic_increasing - and len(self) > 0 - and self[0] == 0 - and self[-1] == len(self) - 1 - ): + elif (self.dtype.kind in ["i", "u"] + and self._is_strictly_monotonic_increasing + and len(self) > 0 and self[0] == 0 + and self[-1] == len(self) - 1): # We are range-like, e.g. created with Index(np.arange(N)) pass elif not is_index_slice: @@ -4254,8 +4275,7 @@ def _invalid_indexer(self, form: str_t, key) -> TypeError: """ return TypeError( f"cannot do {form} indexing on {type(self).__name__} with these " - f"indexers [{key}] of type {type(key).__name__}" - ) + f"indexers [{key}] of type {type(key).__name__}") # -------------------------------------------------------------------- # Reindex Methods @@ -4277,9 +4297,12 @@ def _validate_can_reindex(self, indexer: np.ndarray) -> None: if not self._index_as_unique and len(indexer): raise ValueError("cannot reindex on an axis with duplicate labels") - def reindex( - self, target, method=None, level=None, limit=None, tolerance=None - ) -> tuple[Index, npt.NDArray[np.intp] | None]: + def reindex(self, + target, + method=None, + level=None, + limit=None, + tolerance=None) -> tuple[Index, npt.NDArray[np.intp] | None]: """ Create index with target's values. @@ -4361,25 +4384,23 @@ def reindex( # TODO: tests where passing `keep_order=not self._is_multi` # makes a difference for non-MultiIndex case target, indexer, _ = self._join_level( - target, level, how="right", keep_order=not self._is_multi - ) + target, level, how="right", keep_order=not self._is_multi) else: if self.equals(target): indexer = None else: if self._index_as_unique: - indexer = self.get_indexer( - target, method=method, limit=limit, tolerance=tolerance - ) + indexer = self.get_indexer(target, + method=method, + limit=limit, + tolerance=tolerance) elif self._is_multi: raise ValueError("cannot handle a non-unique multi-index!") else: if method is not None or limit is not None: - raise ValueError( - "cannot reindex a non-unique index " - "with a method or limit" - ) + raise ValueError("cannot reindex a non-unique index " + "with a method or limit") indexer, _ = self.get_indexer_non_unique(target) if not self.is_unique: @@ -4445,7 +4466,7 @@ def _reindex_non_unique( cur_indexer = length[check] # Index constructor below will do inference - new_labels = np.empty((len(indexer),), dtype=object) + new_labels = np.empty((len(indexer), ), dtype=object) new_labels[cur_indexer] = cur_labels new_labels[missing_indexer] = missing_labels @@ -4512,18 +4533,22 @@ def join( """ other = ensure_index(other) - if isinstance(self, ABCDatetimeIndex) and isinstance(other, ABCDatetimeIndex): + if isinstance(self, ABCDatetimeIndex) and isinstance( + other, ABCDatetimeIndex): if (self.tz is None) ^ (other.tz is None): # Raise instead of casting to object below. 
- raise TypeError("Cannot join tz-naive with tz-aware DatetimeIndex") + raise TypeError( + "Cannot join tz-naive with tz-aware DatetimeIndex") if not self._is_multi and not other._is_multi: # We have specific handling for MultiIndex below pself, pother = self._maybe_promote(other) if pself is not self or pother is not other: - return pself.join( - pother, how=how, level=level, return_indexers=True, sort=sort - ) + return pself.join(pother, + how=how, + level=level, + return_indexers=True, + sort=sort) lindexer: np.ndarray | None rindexer: np.ndarray | None @@ -4564,9 +4589,10 @@ def join( if self._join_precedence < other._join_precedence: how = {"right": "left", "left": "right"}.get(how, how) - join_index, lidx, ridx = other.join( - self, how=how, level=level, return_indexers=True - ) + join_index, lidx, ridx = other.join(self, + how=how, + level=level, + return_indexers=True) lidx, ridx = ridx, lidx return join_index, lidx, ridx @@ -4588,14 +4614,11 @@ def join( else: return self._join_non_unique(other, how=how) elif ( - self.is_monotonic_increasing - and other.is_monotonic_increasing - and self._can_use_libjoin - and ( - not isinstance(self, ABCMultiIndex) - or not any(is_categorical_dtype(dtype) for dtype in self.dtypes) - ) - ): + self.is_monotonic_increasing and other.is_monotonic_increasing + and self._can_use_libjoin and + (not isinstance(self, ABCMultiIndex) + or not any(is_categorical_dtype(dtype) + for dtype in self.dtypes))): # Categorical is monotonic if data are ordered as categories, but join can # not handle this in case of not lexicographically monotonic GH#38502 try: @@ -4609,7 +4632,8 @@ def join( @final def _join_via_get_indexer( self, other: Index, how: str_t, sort: bool - ) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]: + ) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] + | None]: # Fallback if we do not have any fastpaths available based on # uniqueness/monotonicity @@ -4675,9 +4699,9 @@ def _join_multi(self, other: Index, how: str_t): # Join left and right # Join on same leveled multi-index frames is supported - join_idx, lidx, ridx = self_jnlevels.join( - other_jnlevels, how, return_indexers=True - ) + join_idx, lidx, ridx = self_jnlevels.join(other_jnlevels, + how, + return_indexers=True) # Restore the dropped levels # Returned index level order is @@ -4685,13 +4709,13 @@ def _join_multi(self, other: Index, how: str_t): dropped_names = ldrop_names + rdrop_names levels, codes, names = restore_dropped_levels_multijoin( - self, other, dropped_names, join_idx, lidx, ridx - ) + self, other, dropped_names, join_idx, lidx, ridx) # Re-create the multi-index - multi_join_idx = MultiIndex( - levels=levels, codes=codes, names=names, verify_integrity=False - ) + multi_join_idx = MultiIndex(levels=levels, + codes=codes, + names=names, + verify_integrity=False) multi_join_idx = multi_join_idx.remove_unused_levels() @@ -4717,16 +4741,19 @@ def _join_multi(self, other: Index, how: str_t): @final def _join_non_unique( - self, other: Index, how: str_t = "left" + self, + other: Index, + how: str_t = "left" ) -> tuple[Index, npt.NDArray[np.intp], npt.NDArray[np.intp]]: from pandas.core.reshape.merge import get_join_indexers # We only get here if dtypes match assert self.dtype == other.dtype - left_idx, right_idx = get_join_indexers( - [self._values], [other._values], how=how, sort=True - ) + left_idx, right_idx = get_join_indexers([self._values], + [other._values], + how=how, + sort=True) mask = left_idx == -1 join_array = 
self._values.take(left_idx) @@ -4748,8 +4775,13 @@ def _join_non_unique( @final def _join_level( - self, other: Index, level, how: str_t = "left", keep_order: bool = True - ) -> tuple[MultiIndex, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]: + self, + other: Index, + level, + how: str_t = "left", + keep_order: bool = True + ) -> tuple[MultiIndex, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] + | None]: """ The join method *only* affects the level of the resulting MultiIndex. Otherwise it just exactly aligns the Index data to the @@ -4792,7 +4824,8 @@ def _get_leaf_sorter(labels: list[np.ndarray]) -> npt.NDArray[np.intp]: return lib.get_level_sorter(lab, ensure_platform_int(starts)) if isinstance(self, MultiIndex) and isinstance(other, MultiIndex): - raise TypeError("Join on level between two MultiIndex objects is ambiguous") + raise TypeError( + "Join on level between two MultiIndex objects is ambiguous") left, right = self, other @@ -4808,24 +4841,23 @@ def _get_leaf_sorter(labels: list[np.ndarray]) -> npt.NDArray[np.intp]: if not right.is_unique: raise NotImplementedError( - "Index._join_level on non-unique index is not implemented" - ) + "Index._join_level on non-unique index is not implemented") new_level, left_lev_indexer, right_lev_indexer = old_level.join( - right, how=how, return_indexers=True - ) + right, how=how, return_indexers=True) if left_lev_indexer is None: if keep_order or len(left) == 0: left_indexer = None join_index = left else: # sort the leaves - left_indexer = _get_leaf_sorter(left.codes[: level + 1]) + left_indexer = _get_leaf_sorter(left.codes[:level + 1]) join_index = left[left_indexer] else: left_lev_indexer = ensure_platform_int(left_lev_indexer) - rev_indexer = lib.get_reverse_indexer(left_lev_indexer, len(old_level)) + rev_indexer = lib.get_reverse_indexer(left_lev_indexer, + len(old_level)) old_codes = left.codes[level] taker = old_codes[old_codes != -1] @@ -4847,14 +4879,14 @@ def _get_leaf_sorter(labels: list[np.ndarray]) -> npt.NDArray[np.intp]: else: # tie out the order with other if level == 0: # outer most level, take the fast route - max_new_lev = 0 if len(new_lev_codes) == 0 else new_lev_codes.max() + max_new_lev = 0 if len( + new_lev_codes) == 0 else new_lev_codes.max() ngroups = 1 + max_new_lev left_indexer, counts = libalgos.groupsort_indexer( - new_lev_codes, ngroups - ) + new_lev_codes, ngroups) # missing values are placed first; drop them! - left_indexer = left_indexer[counts[0] :] + left_indexer = left_indexer[counts[0]:] new_codes = [lab[left_indexer] for lab in new_codes] else: # sort the leaves @@ -4863,7 +4895,7 @@ def _get_leaf_sorter(labels: list[np.ndarray]) -> npt.NDArray[np.intp]: if not mask_all: new_codes = [lab[mask] for lab in new_codes] - left_indexer = _get_leaf_sorter(new_codes[: level + 1]) + left_indexer = _get_leaf_sorter(new_codes[:level + 1]) new_codes = [lab[left_indexer] for lab in new_codes] # left_indexers are w.r.t masked frame. 
@@ -4886,18 +4918,19 @@ def _get_leaf_sorter(labels: list[np.ndarray]) -> npt.NDArray[np.intp]: if flip_order: left_indexer, right_indexer = right_indexer, left_indexer - left_indexer = ( - None if left_indexer is None else ensure_platform_int(left_indexer) - ) - right_indexer = ( - None if right_indexer is None else ensure_platform_int(right_indexer) - ) + left_indexer = (None if left_indexer is None else + ensure_platform_int(left_indexer)) + right_indexer = (None if right_indexer is None else + ensure_platform_int(right_indexer)) return join_index, left_indexer, right_indexer @final def _join_monotonic( - self, other: Index, how: str_t = "left" - ) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]: + self, + other: Index, + how: str_t = "left" + ) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] + | None]: # We only get here with matching dtypes and both monotonic increasing assert other.dtype == self.dtype @@ -4940,14 +4973,16 @@ def _join_monotonic( ridx = None if ridx is None else ensure_platform_int(ridx) return join_index, lidx, ridx - def _wrap_joined_index(self: _IndexT, joined: ArrayLike, other: _IndexT) -> _IndexT: + def _wrap_joined_index(self: _IndexT, joined: ArrayLike, + other: _IndexT) -> _IndexT: assert other.dtype == self.dtype if isinstance(self, ABCMultiIndex): name = self.names if self.names == other.names else None # error: Incompatible return value type (got "MultiIndex", # expected "_IndexT") - return self._constructor(joined, name=name) # type: ignore[return-value] + return self._constructor(joined, + name=name) # type: ignore[return-value] else: name = get_op_result_name(self, other) return self._constructor._with_infer(joined, name=name) @@ -5082,8 +5117,7 @@ def where(self, cond, other=None) -> Index: """ if isinstance(self, ABCMultiIndex): raise NotImplementedError( - ".where is not supported for MultiIndex operations" - ) + ".where is not supported for MultiIndex operations") cond = np.asarray(cond, dtype=bool) return self.putmask(~cond, other) @@ -5095,16 +5129,13 @@ def _scalar_data_error(cls, data): # in order to keep mypy happy return TypeError( f"{cls.__name__}(...) must be called with a collection of some " - f"kind, {repr(data)} was passed" - ) + f"kind, {repr(data)} was passed") @final @classmethod def _string_data_error(cls, data): - raise TypeError( - "String dtype not supported, you may need " - "to explicitly cast to a numeric type" - ) + raise TypeError("String dtype not supported, you may need " + "to explicitly cast to a numeric type") def _validate_fill_value(self, value): """ @@ -5135,7 +5166,8 @@ def _require_scalar(self, value): operations without changing dtype. """ if not is_scalar(value): - raise TypeError(f"'value' must be a scalar, passed: {type(value).__name__}") + raise TypeError( + f"'value' must be a scalar, passed: {type(value).__name__}") return value def _is_memory_usage_qualified(self) -> bool: @@ -5355,7 +5387,8 @@ def putmask(self, mask, value) -> Index: values = self._values.copy() if isinstance(values, np.ndarray): - converted = setitem_datetimelike_compat(values, mask.sum(), converted) + converted = setitem_datetimelike_compat(values, mask.sum(), + converted) np.putmask(values, mask, converted) else: @@ -5463,14 +5496,9 @@ def identical(self, other) -> bool: If two Index objects have equal elements and same type True, otherwise False. 
""" - return ( - self.equals(other) - and all( - getattr(self, c, None) == getattr(other, c, None) - for c in self._comparables - ) - and type(self) == type(other) - ) + return (self.equals(other) and all( + getattr(self, c, None) == getattr(other, c, None) + for c in self._comparables) and type(self) == type(other)) @final def asof(self, label): @@ -5548,9 +5576,8 @@ def asof(self, label): return self[loc] - def asof_locs( - self, where: Index, mask: npt.NDArray[np.bool_] - ) -> npt.NDArray[np.intp]: + def asof_locs(self, where: Index, + mask: npt.NDArray[np.bool_]) -> npt.NDArray[np.intp]: """ Return the locations (indices) of labels in the index. @@ -5582,7 +5609,8 @@ def asof_locs( # types "Union[ExtensionArray, ndarray[Any, Any]]", "str" # TODO: will be fixed when ExtensionArray.searchsorted() is fixed locs = self._values[mask].searchsorted( - where._values, side="right" # type: ignore[call-overload] + where._values, + side="right" # type: ignore[call-overload] ) locs = np.where(locs > 0, locs - 1, 0) @@ -5661,9 +5689,10 @@ def sort_values( # GH 35584. Sort missing values according to na_position kwarg # ignore na_position for MultiIndex if not isinstance(self, ABCMultiIndex): - _as = nargsort( - items=idx, ascending=ascending, na_position=na_position, key=key - ) + _as = nargsort(items=idx, + ascending=ascending, + na_position=na_position, + key=key) else: _as = idx.argsort() if not ascending: @@ -5681,7 +5710,8 @@ def sort(self, *args, **kwargs): """ Use sort_values instead. """ - raise TypeError("cannot sort an Index object in-place, use sort_values instead") + raise TypeError( + "cannot sort an Index object in-place, use sort_values instead") def shift(self, periods=1, freq=None): """ @@ -5741,8 +5771,7 @@ def shift(self, periods=1, freq=None): """ raise NotImplementedError( f"This method is only implemented for DatetimeIndex, PeriodIndex and " - f"TimedeltaIndex; Got type {type(self).__name__}" - ) + f"TimedeltaIndex; Got type {type(self).__name__}") def argsort(self, *args, **kwargs) -> npt.NDArray[np.intp]: """ @@ -5861,10 +5890,8 @@ def set_value(self, arr, key, value): Only use this if you know what you're doing. """ warnings.warn( - ( - "The 'set_value' method is deprecated, and " - "will be removed in a future version." - ), + ("The 'set_value' method is deprecated, and " + "will be removed in a future version."), FutureWarning, stacklevel=find_stack_level(), ) @@ -5873,9 +5900,7 @@ def set_value(self, arr, key, value): raise ValueError arr[loc] = value - _index_shared_docs[ - "get_indexer_non_unique" - ] = """ + _index_shared_docs["get_indexer_non_unique"] = """ Compute indexer and mask for new index given the current index. The indexer should be then used as an input to ndarray.take to align the current data to the new index. 
@@ -5897,15 +5922,17 @@ def set_value(self, arr, key, value): @Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs) def get_indexer_non_unique( - self, target - ) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: + self, target) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: target = ensure_index(target) target = self._maybe_cast_listlike_indexer(target) - if not self._should_compare(target) and not is_interval_dtype(self.dtype): + if not self._should_compare(target) and not is_interval_dtype( + self.dtype): # IntervalIndex get special treatment bc numeric scalars can be # matched to Interval scalars - return self._get_indexer_non_comparable(target, method=None, unique=False) + return self._get_indexer_non_comparable(target, + method=None, + unique=False) pself, ptarget = self._maybe_promote(target) if pself is not self or ptarget is not target: @@ -5927,7 +5954,8 @@ def get_indexer_non_unique( engine = self._engine # Item "IndexEngine" of "Union[IndexEngine, ExtensionEngine]" has # no attribute "_extract_level_codes" - tgt_values = engine._extract_level_codes(target) # type: ignore[union-attr] + tgt_values = engine._extract_level_codes( + target) # type: ignore[union-attr] indexer, missing = self._engine.get_indexer_non_unique(tgt_values) return ensure_platform_int(indexer), ensure_platform_int(missing) @@ -5956,7 +5984,8 @@ def get_indexer_for(self, target) -> npt.NDArray[np.intp]: indexer, _ = self.get_indexer_non_unique(target) return indexer - def _get_indexer_strict(self, key, axis_name: str_t) -> tuple[Index, np.ndarray]: + def _get_indexer_strict(self, key, + axis_name: str_t) -> tuple[Index, np.ndarray]: """ Analogue to get_indexer that raises if any elements are missing. """ @@ -5979,9 +6008,9 @@ def _get_indexer_strict(self, key, axis_name: str_t) -> tuple[Index, np.ndarray] if keyarr.dtype.kind in ["m", "M"]: # DTI/TDI.take can infer a freq in some cases when we dont want one if isinstance(key, list) or ( - isinstance(key, type(self)) - # "Index" has no attribute "freq" - and key.freq is None # type: ignore[attr-defined] + isinstance(key, type(self)) + # "Index" has no attribute "freq" + and key.freq is None # type: ignore[attr-defined] ): keyarr = keyarr._with_freq(None) @@ -6024,21 +6053,23 @@ def _raise_if_missing(self, key, indexer, axis_name: str_t) -> None: # "Index" has no attribute "categories" [attr-defined] and is_interval_dtype( self.categories.dtype # type: ignore[attr-defined] - ) - ) + )) if nmissing == len(indexer): if use_interval_msg: key = list(key) raise KeyError(f"None of [{key}] are in the [{axis_name}]") - not_found = list(ensure_index(key)[missing_mask.nonzero()[0]].unique()) + not_found = list( + ensure_index(key)[missing_mask.nonzero()[0]].unique()) raise KeyError(f"{not_found} not in index") @overload def _get_indexer_non_comparable( - self, target: Index, method, unique: Literal[True] = ... - ) -> npt.NDArray[np.intp]: + self, + target: Index, + method, + unique: Literal[True] = ...) -> npt.NDArray[np.intp]: ... @overload @@ -6049,14 +6080,22 @@ def _get_indexer_non_comparable( @overload def _get_indexer_non_comparable( - self, target: Index, method, unique: bool = True - ) -> npt.NDArray[np.intp] | tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: + self, + target: Index, + method, + unique: bool = True + ) -> npt.NDArray[np.intp] | tuple[npt.NDArray[np.intp], + npt.NDArray[np.intp]]: ... 
@final def _get_indexer_non_comparable( - self, target: Index, method, unique: bool = True - ) -> npt.NDArray[np.intp] | tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: + self, + target: Index, + method, + unique: bool = True + ) -> npt.NDArray[np.intp] | tuple[npt.NDArray[np.intp], + npt.NDArray[np.intp]]: """ Called from get_indexer or get_indexer_non_unique when the target is of a non-comparable dtype. @@ -6082,7 +6121,8 @@ def _get_indexer_non_comparable( """ if method is not None: other = unpack_nested_dtype(target) - raise TypeError(f"Cannot compare dtypes {self.dtype} and {other.dtype}") + raise TypeError( + f"Cannot compare dtypes {self.dtype} and {other.dtype}") no_matches = -1 * np.ones(target.shape, dtype=np.intp) if unique: @@ -6112,21 +6152,21 @@ def _maybe_promote(self, other: Index) -> tuple[Index, Index]: if we can upcast the object-dtype one to improve performance. """ - if isinstance(self, ABCDatetimeIndex) and isinstance(other, ABCDatetimeIndex): - if ( - self.tz is not None - and other.tz is not None - and not tz_compare(self.tz, other.tz) - ): + if isinstance(self, ABCDatetimeIndex) and isinstance( + other, ABCDatetimeIndex): + if (self.tz is not None and other.tz is not None + and not tz_compare(self.tz, other.tz)): # standardize on UTC return self.tz_convert("UTC"), other.tz_convert("UTC") - elif self.inferred_type == "date" and isinstance(other, ABCDatetimeIndex): + elif self.inferred_type == "date" and isinstance( + other, ABCDatetimeIndex): try: return type(other)(self), other except OutOfBoundsDatetime: return self, other - elif self.inferred_type == "timedelta" and isinstance(other, ABCTimedeltaIndex): + elif self.inferred_type == "timedelta" and isinstance( + other, ABCTimedeltaIndex): # TODO: we dont have tests that get here return type(other)(self), other @@ -6140,7 +6180,8 @@ def _maybe_promote(self, other: Index) -> tuple[Index, Index]: elif self._is_multi and not other._is_multi: try: # "Type[Index]" has no attribute "from_tuples" - other = type(self).from_tuples(other) # type: ignore[attr-defined] + other = type(self).from_tuples( + other) # type: ignore[attr-defined] except (TypeError, ValueError): # let's instead try with a straight Index self = Index(self._values) @@ -6163,8 +6204,7 @@ def _find_common_type_compat(self, target) -> DtypeObj: if is_dtype_equal(self.dtype, dtype): raise NotImplementedError( "This should not be reached. Please report a bug at " - "github.com/pandas-dev/pandas" - ) + "github.com/pandas-dev/pandas") return dtype target_dtype, _ = infer_dtype_from(target, pandas_dtype=True) @@ -6176,9 +6216,8 @@ def _find_common_type_compat(self, target) -> DtypeObj: # * uint64 | signed int -> object # We may change union(float | [u]int) to go to object. if self.dtype == "uint64" or target_dtype == "uint64": - if is_signed_integer_dtype(self.dtype) or is_signed_integer_dtype( - target_dtype - ): + if is_signed_integer_dtype( + self.dtype) or is_signed_integer_dtype(target_dtype): return _dtype_obj dtype = find_common_type([self.dtype, target_dtype]) @@ -6191,9 +6230,9 @@ def _should_compare(self, other: Index) -> bool: Check if `self == other` can ever have non-False entries. """ - if (other.is_boolean() and self.is_numeric()) or ( - self.is_boolean() and other.is_numeric() - ): + if (other.is_boolean() + and self.is_numeric()) or (self.is_boolean() + and other.is_numeric()): # GH#16877 Treat boolean labels passed to a numeric index as not # found. Without this fix False and True would be treated as 0 and 1 # respectively. 
@@ -6281,20 +6320,24 @@ def map(self, mapper, na_action=None): # e.g. if we are floating and new_values is all ints, then we # don't want to cast back to floating. But if we are UInt64 # and new_values is all ints, we want to try. - same_dtype = lib.infer_dtype(new_values, skipna=False) == self.inferred_type + same_dtype = lib.infer_dtype(new_values, + skipna=False) == self.inferred_type if same_dtype: - new_values = maybe_cast_pointwise_result( - new_values, self.dtype, same_dtype=same_dtype - ) + new_values = maybe_cast_pointwise_result(new_values, + self.dtype, + same_dtype=same_dtype) if self._is_backward_compat_public_numeric_index and is_numeric_dtype( - new_values.dtype - ): - return self._constructor( - new_values, dtype=dtype, copy=False, name=self.name - ) + new_values.dtype): + return self._constructor(new_values, + dtype=dtype, + copy=False, + name=self.name) - return Index._with_infer(new_values, dtype=dtype, copy=False, name=self.name) + return Index._with_infer(new_values, + dtype=dtype, + copy=False, + name=self.name) # TODO: De-duplicate with map, xref GH#32349 @final @@ -6309,7 +6352,8 @@ def _transform_index(self, func, *, level=None) -> Index: if level is not None: # Caller is responsible for ensuring level is positional. items = [ - tuple(func(y) if i == level else y for i, y in enumerate(x)) + tuple( + func(y) if i == level else y for i, y in enumerate(x)) for x in self ] else: @@ -6527,7 +6571,9 @@ def _maybe_cast_slice_bound(self, label, side: str_t, kind=no_default): return label - def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"): + def _searchsorted_monotonic(self, + label, + side: Literal["left", "right"] = "left"): if self.is_monotonic_increasing: return self.searchsorted(label, side=side) elif self.is_monotonic_decreasing: @@ -6535,15 +6581,15 @@ def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left" # everything for it to work (element ordering, search side and # resulting value). pos = self[::-1].searchsorted( - label, side="right" if side == "left" else "left" - ) + label, side="right" if side == "left" else "left") return len(self) - pos raise ValueError("index must be monotonic increasing or decreasing") - def get_slice_bound( - self, label, side: Literal["left", "right"], kind=no_default - ) -> int: + def get_slice_bound(self, + label, + side: Literal["left", "right"], + kind=no_default) -> int: """ Calculate slice bound that corresponds to given label. 
@@ -6567,10 +6613,8 @@ def get_slice_bound( self._deprecated_arg(kind, "kind", "get_slice_bound") if side not in ("left", "right"): - raise ValueError( - "Invalid value for side kwarg, must be either " - f"'left' or 'right': {side}" - ) + raise ValueError("Invalid value for side kwarg, must be either " + f"'left' or 'right': {side}") original_label = label @@ -6594,10 +6638,8 @@ def get_slice_bound( assert is_bool_dtype(slc.dtype) slc = lib.maybe_booleans_to_slice(slc.view("u1")) if isinstance(slc, np.ndarray): - raise KeyError( - f"Cannot get {side} slice bound for non-unique " - f"label: {repr(original_label)}" - ) + raise KeyError(f"Cannot get {side} slice bound for non-unique " + f"label: {repr(original_label)}") if isinstance(slc, slice): if side == "left": @@ -6610,9 +6652,11 @@ def get_slice_bound( else: return slc - def slice_locs( - self, start=None, end=None, step=None, kind=no_default - ) -> tuple[int, int]: + def slice_locs(self, + start=None, + end=None, + step=None, + kind=no_default) -> tuple[int, int]: """ Compute slice locations for input labels. @@ -6655,7 +6699,8 @@ def slice_locs( # GH 16785: If start and end happen to be date strings with UTC offsets # attempt to parse and check that the offsets are the same - if isinstance(start, (str, datetime)) and isinstance(end, (str, datetime)): + if isinstance(start, + (str, datetime)) and isinstance(end, (str, datetime)): try: ts_start = Timestamp(start) ts_end = Timestamp(end) @@ -6663,7 +6708,8 @@ def slice_locs( pass else: if not tz_compare(ts_start.tzinfo, ts_end.tzinfo): - raise ValueError("Both dates must have the same UTC offset") + raise ValueError( + "Both dates must have the same UTC offset") start_slice = None if start is not None: @@ -6779,8 +6825,7 @@ def insert(self, loc: int, item) -> Index: return self.astype(dtype).insert(loc, item) if arr.dtype != object or not isinstance( - item, (tuple, np.datetime64, np.timedelta64) - ): + item, (tuple, np.datetime64, np.timedelta64)): # with object-dtype we need to worry about numpy incorrectly casting # dt64/td64 to integer, also about treating tuples as sequences # special-casing dt64/td64 https://github.com/numpy/numpy/issues/12550 @@ -6790,7 +6835,8 @@ def insert(self, loc: int, item) -> Index: else: # No overload variant of "insert" matches argument types # "ndarray[Any, Any]", "int", "None" [call-overload] - new_values = np.insert(arr, loc, None) # type: ignore[call-overload] + new_values = np.insert(arr, loc, + None) # type: ignore[call-overload] loc = loc if loc >= 0 else loc - 1 new_values[loc] = item @@ -6852,9 +6898,8 @@ def _cmp_method(self, other, op): arr[self.isna()] = True return arr - if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)) and len( - self - ) != len(other): + if isinstance(other, (np.ndarray, Index, ABCSeries, + ExtensionArray)) and len(self) != len(other): raise ValueError("Lengths must match to compare") if not isinstance(other, ABCMultiIndex): @@ -6870,7 +6915,8 @@ def _cmp_method(self, other, op): elif isinstance(self._values, ExtensionArray): result = op(self._values, other) - elif is_object_dtype(self.dtype) and not isinstance(self, ABCMultiIndex): + elif is_object_dtype( + self.dtype) and not isinstance(self, ABCMultiIndex): # don't pass MultiIndex with np.errstate(all="ignore"): result = ops.comp_method_OBJECT_ARRAY(op, self._values, other) @@ -6890,11 +6936,8 @@ def _construct_result(self, result, name): return Index._with_infer(result, name=name) def _arith_method(self, other, op): - if ( - isinstance(other, Index) - 
and is_object_dtype(other.dtype) - and type(other) is not Index - ): + if (isinstance(other, Index) and is_object_dtype(other.dtype) + and type(other) is not Index): # We return NotImplemented for object-dtype index *subclasses* so they have # a chance to implement ops before we unwrap them. # See https://github.com/pandas-dev/pandas/issues/31109 @@ -7019,13 +7062,10 @@ def _maybe_disable_logical_methods(self, opname: str_t) -> None: """ raise if this Index subclass does not support any or all. """ - if ( - isinstance(self, ABCMultiIndex) - or needs_i8_conversion(self.dtype) - or is_interval_dtype(self.dtype) - or is_categorical_dtype(self.dtype) - or is_float_dtype(self.dtype) - ): + if (isinstance(self, ABCMultiIndex) or needs_i8_conversion(self.dtype) + or is_interval_dtype(self.dtype) + or is_categorical_dtype(self.dtype) + or is_float_dtype(self.dtype)): # This call will raise make_invalid_op(opname)(self) @@ -7075,7 +7115,8 @@ def min(self, axis=None, skipna=True, *args, **kwargs): if not self._is_multi and not isinstance(self._values, np.ndarray): # "ExtensionArray" has no attribute "min" - return self._values.min(skipna=skipna) # type: ignore[attr-defined] + return self._values.min( + skipna=skipna) # type: ignore[attr-defined] return super().min(skipna=skipna) @@ -7101,7 +7142,8 @@ def max(self, axis=None, skipna=True, *args, **kwargs): if not self._is_multi and not isinstance(self._values, np.ndarray): # "ExtensionArray" has no attribute "max" - return self._values.max(skipna=skipna) # type: ignore[attr-defined] + return self._values.max( + skipna=skipna) # type: ignore[attr-defined] return super().max(skipna=skipna) @@ -7114,7 +7156,7 @@ def shape(self) -> Shape: Return a tuple of the shape of the underlying data. """ # See GH#27775, GH#27384 for history/reasoning in how this is defined. - return (len(self),) + return (len(self), ) @final def _deprecated_arg(self, value, name: str_t, methodname: str_t) -> None: @@ -7170,7 +7212,8 @@ def ensure_index_from_sequences(sequences, names=None) -> Index: return MultiIndex.from_arrays(sequences, names=names) -def ensure_index(index_like: AnyArrayLike | Sequence, copy: bool = False) -> Index: +def ensure_index(index_like: AnyArrayLike | Sequence, + copy: bool = False) -> Index: """ Ensure that we have an index from some index-like object. @@ -7224,7 +7267,9 @@ def ensure_index(index_like: AnyArrayLike | Sequence, copy: bool = False) -> Ind return MultiIndex.from_arrays(index_like) else: - return Index._with_infer(index_like, copy=copy, tupleize_cols=False) + return Index._with_infer(index_like, + copy=copy, + tupleize_cols=False) else: return Index._with_infer(index_like, copy=copy) @@ -7285,13 +7330,12 @@ def maybe_extract_name(name, obj, cls) -> Hashable: "In a future version, passing an object-dtype arraylike to pd.Index will " "not infer numeric values to numeric dtype (matching the Series behavior). " "To retain the old behavior, explicitly pass the desired dtype or use the " - "desired Index subclass" -) + "desired Index subclass") -def _maybe_cast_data_without_dtype( - subarr: np.ndarray, cast_numeric_deprecated: bool = True -) -> ArrayLike: +def _maybe_cast_data_without_dtype(subarr: np.ndarray, + cast_numeric_deprecated: bool = True + ) -> ArrayLike: """ If we have an arraylike input but no passed dtype, try to infer a supported dtype. 
diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index b9f52556eb79e..31c8e1b6a98e0 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -167,23 +167,23 @@ def format( header = [] if name: header.append( - ibase.pprint_thing(self.name, escape_chars=("\t", "\r", "\n")) - if self.name is not None - else "" - ) + ibase.pprint_thing(self.name, escape_chars=( + "\t", "\r", "\n")) if self.name is not None else "") if formatter is not None: return header + list(self.map(formatter)) - return self._format_with_header(header, na_rep=na_rep, date_format=date_format) + return self._format_with_header(header, + na_rep=na_rep, + date_format=date_format) - def _format_with_header( - self, header: list[str], na_rep: str = "NaT", date_format: str | None = None - ) -> list[str]: + def _format_with_header(self, + header: list[str], + na_rep: str = "NaT", + date_format: str | None = None) -> list[str]: # matches base class except for whitespace padding and date_format return header + list( - self._format_native_types(na_rep=na_rep, date_format=date_format) - ) + self._format_native_types(na_rep=na_rep, date_format=date_format)) @property def _formatter_func(self): @@ -259,9 +259,8 @@ def _partial_date_slice( if self.is_monotonic_increasing: - if len(self) and ( - (t1 < self[0] and t2 < self[0]) or (t1 > self[-1] and t2 > self[-1]) - ): + if len(self) and ((t1 < self[0] and t2 < self[0]) or + (t1 > self[-1] and t2 > self[-1])): # we are out of range raise KeyError @@ -432,9 +431,9 @@ def _wrap_range_setop(self, other, res_i8): # This raising is incorrect, as 'on_freq' is incorrect. This will # be fixed by GH#41493 res_values = res_i8.values.view(self._data._ndarray.dtype) - result = type(self._data)._simple_new( - res_values, dtype=self.dtype, freq=new_freq - ) + result = type(self._data)._simple_new(res_values, + dtype=self.dtype, + freq=new_freq) return self._wrap_setop_result(other, result) def _range_intersect(self, other, sort): @@ -637,8 +636,7 @@ def _get_delete_freq(self, loc: int | slice | Sequence[int]): # type "Union[slice, ndarray]", variable has type # "Union[int, slice, Sequence[int]]") loc = lib.maybe_indices_to_slice( # type: ignore[assignment] - np.asarray(loc, dtype=np.intp), len(self) - ) + np.asarray(loc, dtype=np.intp), len(self)) if isinstance(loc, slice) and loc.step in (1, None): if loc.start in (0, None) or loc.stop in (len(self), None): freq = self.freq @@ -689,13 +687,18 @@ def insert(self, loc: int, item): # NDArray-Like Methods @Appender(_index_shared_docs["take"] % _index_doc_kwargs) - def take(self, indices, axis=0, allow_fill=True, fill_value=None, **kwargs): + def take(self, + indices, + axis=0, + allow_fill=True, + fill_value=None, + **kwargs): nv.validate_take((), kwargs) indices = np.asarray(indices, dtype=np.intp) - result = NDArrayBackedExtensionIndex.take( - self, indices, axis, allow_fill, fill_value, **kwargs - ) + result = NDArrayBackedExtensionIndex.take(self, indices, axis, + allow_fill, fill_value, + **kwargs) maybe_slice = lib.maybe_indices_to_slice(indices, len(self)) if isinstance(maybe_slice, slice): diff --git a/pandas/core/indexes/range.py b/pandas/core/indexes/range.py index 98e032c60054c..109dee2437fc6 100644 --- a/pandas/core/indexes/range.py +++ b/pandas/core/indexes/range.py @@ -146,9 +146,10 @@ def __new__( return cls._simple_new(rng, name=name) @classmethod - def from_range( - cls, data: range, name=None, dtype: Dtype | None = None - ) -> RangeIndex: + def from_range(cls, + 
data: range, + name=None, + dtype: Dtype | None = None) -> RangeIndex: """ Create RangeIndex from a range object. @@ -159,8 +160,7 @@ def from_range( if not isinstance(data, range): raise TypeError( f"{cls.__name__}(...) must be called with object coercible to a " - f"range, {repr(data)} was passed" - ) + f"range, {repr(data)} was passed") cls._validate_dtype(dtype) return cls._simple_new(data, name=name) @@ -232,11 +232,9 @@ def _format_with_header(self, header: list[str], na_rep: str) -> list[str]: return header + [f"{x:<{max_length}}" for x in self._range] # -------------------------------------------------------------------- - _deprecation_message = ( - "RangeIndex.{} is deprecated and will be " - "removed in a future version. Use RangeIndex.{} " - "instead" - ) + _deprecation_message = ("RangeIndex.{} is deprecated and will be " + "removed in a future version. Use RangeIndex.{} " + "instead") @property def start(self) -> int: @@ -316,8 +314,7 @@ def nbytes(self) -> int: rng = self._range return getsizeof(rng) + sum( getsizeof(getattr(rng, attr_name)) - for attr_name in ["start", "stop", "step"] - ) + for attr_name in ["start", "stop", "step"]) def memory_usage(self, deep: bool = False) -> int: """ @@ -397,9 +394,10 @@ def _get_indexer( tolerance=None, ) -> npt.NDArray[np.intp]: if com.any_not_none(method, tolerance, limit): - return super()._get_indexer( - target, method=method, tolerance=tolerance, limit=limit - ) + return super()._get_indexer(target, + method=method, + tolerance=tolerance, + limit=limit) if self.step > 0: start, stop, step = self.start, self.stop, self.step @@ -466,7 +464,8 @@ def _minmax(self, meth: str): no_steps = len(self) - 1 if no_steps == -1: return np.nan - elif (meth == "min" and self.step > 0) or (meth == "max" and self.step < 0): + elif (meth == "min" and self.step > 0) or (meth == "max" + and self.step < 0): return self.start return self.start + self.step * no_steps @@ -510,7 +509,9 @@ def argsort(self, *args, **kwargs) -> npt.NDArray[np.intp]: return result def factorize( - self, sort: bool = False, na_sentinel: int | None = -1 + self, + sort: bool = False, + na_sentinel: int | None = -1 ) -> tuple[npt.NDArray[np.intp], RangeIndex]: codes = np.arange(len(self), dtype=np.intp) uniques = self @@ -591,7 +592,8 @@ def _intersection(self, other: Index, sort=False): # calculate parameters for the RangeIndex describing the # intersection disregarding the lower bounds - tmp_start = first.start + (second.start - first.start) * first.step // gcd * s + tmp_start = first.start + (second.start - + first.start) * first.step // gcd * s new_step = first.step * second.step // gcd new_range = range(tmp_start, int_high, new_step) new_index = self._simple_new(new_range) @@ -669,34 +671,26 @@ def _union(self, other: Index, sort): start_r = min(start_s, start_o) end_r = max(end_s, end_o) if step_o == step_s: - if ( - (start_s - start_o) % step_s == 0 - and (start_s - end_o) <= step_s - and (start_o - end_s) <= step_s - ): + if ((start_s - start_o) % step_s == 0 + and (start_s - end_o) <= step_s + and (start_o - end_s) <= step_s): return type(self)(start_r, end_r + step_s, step_s) - if ( - (step_s % 2 == 0) - and (abs(start_s - start_o) == step_s / 2) - and (abs(end_s - end_o) == step_s / 2) - ): + if ((step_s % 2 == 0) + and (abs(start_s - start_o) == step_s / 2) + and (abs(end_s - end_o) == step_s / 2)): # e.g. 
range(0, 10, 2) and range(1, 11, 2) # but not range(0, 20, 4) and range(1, 21, 4) GH#44019 return type(self)(start_r, end_r + step_s / 2, step_s / 2) elif step_o % step_s == 0: - if ( - (start_o - start_s) % step_s == 0 - and (start_o + step_s >= start_s) - and (end_o - step_s <= end_s) - ): + if ((start_o - start_s) % step_s == 0 + and (start_o + step_s >= start_s) + and (end_o - step_s <= end_s)): return type(self)(start_r, end_r + step_s, step_s) elif step_s % step_o == 0: - if ( - (start_s - start_o) % step_o == 0 - and (start_s + step_o >= start_o) - and (end_s - step_o <= end_o) - ): + if ((start_s - start_o) % step_o == 0 + and (start_s + step_o >= start_o) + and (end_s - step_o <= end_o)): return type(self)(start_r, end_r + step_o, step_o) return super()._union(other, sort=sort) @@ -740,14 +734,16 @@ def _difference(self, other, sort=None): else: return super()._difference(other, sort=sort) - elif len(overlap) == 2 and overlap[0] == first[0] and overlap[-1] == first[-1]: + elif len(overlap) == 2 and overlap[0] == first[0] and overlap[ + -1] == first[-1]: # e.g. range(-8, 20, 7) and range(13, -9, -3) return self[1:-1] if overlap.step == first.step: if overlap[0] == first.start: # The difference is everything after the intersection - new_rng = range(overlap[-1] + first.step, first.stop, first.step) + new_rng = range(overlap[-1] + first.step, first.stop, + first.step) elif overlap[-1] == first[-1]: # The difference is everything before the intersection new_rng = range(first.start, overlap[0], first.step) @@ -766,11 +762,13 @@ def _difference(self, other, sort=None): assert len(self) > 1 if overlap.step == first.step * 2: - if overlap[0] == first[0] and overlap[-1] in (first[-1], first[-2]): + if overlap[0] == first[0] and overlap[-1] in (first[-1], + first[-2]): # e.g. range(1, 10, 1) and range(1, 10, 2) new_rng = first[1::2] - elif overlap[0] == first[1] and overlap[-1] in (first[-1], first[-2]): + elif overlap[0] == first[1] and overlap[-1] in (first[-1], + first[-2]): # e.g. 
range(1, 10, 1) and range(2, 10, 2) new_rng = first[::2] @@ -788,7 +786,10 @@ def _difference(self, other, sort=None): return new_index - def symmetric_difference(self, other, result_name: Hashable = None, sort=None): + def symmetric_difference(self, + other, + result_name: Hashable = None, + sort=None): if not isinstance(other, RangeIndex) or sort is not None: return super().symmetric_difference(other, result_name, sort) @@ -816,7 +817,8 @@ def delete(self, loc) -> Index: # type: ignore[override] return self[::2] elif lib.is_list_like(loc): - slc = lib.maybe_indices_to_slice(np.asarray(loc, dtype=np.intp), len(self)) + slc = lib.maybe_indices_to_slice(np.asarray(loc, dtype=np.intp), + len(self)) if isinstance(slc, slice): # defer to RangeIndex._difference, which is optimized to return @@ -886,11 +888,12 @@ def _concat(self, indexes: list[Index], name: Hashable) -> Index: step = rng.start - start - non_consecutive = (step != rng.step and len(rng) > 1) or ( - next_ is not None and rng.start != next_ - ) + non_consecutive = (step != rng.step + and len(rng) > 1) or (next_ is not None + and rng.start != next_) if non_consecutive: - result = Int64Index(np.concatenate([x._values for x in rng_indexes])) + result = Int64Index( + np.concatenate([x._values for x in rng_indexes])) return result.rename(name) if step is not None: @@ -932,12 +935,10 @@ def __getitem__(self, key): f"index {key} is out of bounds for axis 0 with size {len(self)}" ) from err elif is_scalar(key): - raise IndexError( - "only integers, slices (`:`), " - "ellipsis (`...`), numpy.newaxis (`None`) " - "and integer or boolean " - "arrays are valid indices" - ) + raise IndexError("only integers, slices (`:`), " + "ellipsis (`...`), numpy.newaxis (`None`) " + "and integer or boolean " + "arrays are valid indices") # fall back to Int64Index return super().__getitem__(key) @@ -952,7 +953,9 @@ def _getitem_slice(self: RangeIndex, slobj: slice) -> RangeIndex: def __floordiv__(self, other): if is_integer(other) and other != 0: - if len(self) == 0 or self.start % other == 0 and self.step % other == 0: + if len( + self + ) == 0 or self.start % other == 0 and self.step % other == 0: start = self.start // other step = self.step // other stop = start + len(self) * step @@ -1003,14 +1006,14 @@ def _arith_method(self, other, op): return super()._arith_method(other, op) if op in [ - operator.pow, - ops.rpow, - operator.mod, - ops.rmod, - operator.floordiv, - ops.rfloordiv, - divmod, - ops.rdivmod, + operator.pow, + ops.rpow, + operator.mod, + ops.rmod, + operator.floordiv, + ops.rfloordiv, + divmod, + ops.rdivmod, ]: return super()._arith_method(other, op) diff --git a/pandas/core/nanops.py b/pandas/core/nanops.py index 65ad48c786dc2..689f6120ca1d5 100644 --- a/pandas/core/nanops.py +++ b/pandas/core/nanops.py @@ -72,14 +72,17 @@ def set_use_bottleneck(v: bool = True) -> None: class disallow: + def __init__(self, *dtypes: Dtype): super().__init__() self.dtypes = tuple(pandas_dtype(dtype).type for dtype in dtypes) def check(self, obj) -> bool: - return hasattr(obj, "dtype") and issubclass(obj.dtype.type, self.dtypes) + return hasattr(obj, "dtype") and issubclass(obj.dtype.type, + self.dtypes) def __call__(self, f: F) -> F: + @functools.wraps(f) def _f(*args, **kwargs): obj_iter = itertools.chain(args, kwargs.values()) @@ -104,6 +107,7 @@ def _f(*args, **kwargs): class bottleneck_switch: + def __init__(self, name=None, **kwargs): self.name = name self.kwargs = kwargs @@ -138,7 +142,8 @@ def f( # It *may* just be `var` return 
_na_for_min_count(values, axis) - if _USE_BOTTLENECK and skipna and _bn_ok_dtype(values.dtype, bn_name): + if _USE_BOTTLENECK and skipna and _bn_ok_dtype( + values.dtype, bn_name): if kwds.get("mask", None) is None: # `mask` is not recognised by bottleneck, would raise # TypeError if called @@ -188,9 +193,9 @@ def _has_infs(result) -> bool: return False -def _get_fill_value( - dtype: DtypeObj, fill_value: Scalar | None = None, fill_value_typ=None -): +def _get_fill_value(dtype: DtypeObj, + fill_value: Scalar | None = None, + fill_value_typ=None): """return the correct fill value for the dtype of the values""" if fill_value is not None: return fill_value @@ -211,8 +216,8 @@ def _get_fill_value( def _maybe_get_mask( - values: np.ndarray, skipna: bool, mask: npt.NDArray[np.bool_] | None -) -> npt.NDArray[np.bool_] | None: + values: np.ndarray, skipna: bool, + mask: npt.NDArray[np.bool_] | None) -> npt.NDArray[np.bool_] | None: """ Compute a mask if and only if necessary. @@ -303,7 +308,8 @@ def _get_values( assert is_scalar(fill_value) # error: Incompatible types in assignment (expression has type "Union[Any, # Union[ExtensionArray, ndarray]]", variable has type "ndarray") - values = extract_array(values, extract_numpy=True) # type: ignore[assignment] + values = extract_array(values, + extract_numpy=True) # type: ignore[assignment] mask = _maybe_get_mask(values, skipna, mask) @@ -320,9 +326,9 @@ def _get_values( # get our fill value (in case we need to provide an alternative # dtype for it) - fill_value = _get_fill_value( - dtype, fill_value=fill_value, fill_value_typ=fill_value_typ - ) + fill_value = _get_fill_value(dtype, + fill_value=fill_value, + fill_value_typ=fill_value_typ) if skipna and (mask is not None) and (fill_value is not None): if mask.any(): @@ -413,14 +419,16 @@ def new_func( result = _wrap_results(result, orig_values.dtype, fill_value=iNaT) if not skipna: assert mask is not None # checked above - result = _mask_datetimelike_result(result, axis, mask, orig_values) + result = _mask_datetimelike_result(result, axis, mask, + orig_values) return result return cast(F, new_func) -def _na_for_min_count(values: np.ndarray, axis: int | None) -> Scalar | np.ndarray: +def _na_for_min_count(values: np.ndarray, + axis: int | None) -> Scalar | np.ndarray: """ Return the missing value for `values`. 
@@ -446,7 +454,7 @@ def _na_for_min_count(values: np.ndarray, axis: int | None) -> Scalar | np.ndarr elif axis is None: return fill_value else: - result_shape = values.shape[:axis] + values.shape[axis + 1 :] + result_shape = values.shape[:axis] + values.shape[axis + 1:] return np.full(result_shape, fill_value, dtype=values.dtype) @@ -460,21 +468,17 @@ def maybe_operate_rowwise(func: F) -> F: @functools.wraps(func) def newfunc(values: np.ndarray, *, axis: int | None = None, **kwargs): - if ( - axis == 1 - and values.ndim == 2 - and values.flags["C_CONTIGUOUS"] - # only takes this path for wide arrays (long dataframes), for threshold see - # https://github.com/pandas-dev/pandas/pull/43311#issuecomment-974891737 - and (values.shape[1] / 1000) > values.shape[0] - and values.dtype != object - and values.dtype != bool - ): + if (axis == 1 and values.ndim == 2 and values.flags["C_CONTIGUOUS"] + # only takes this path for wide arrays (long dataframes), for threshold see + # https://github.com/pandas-dev/pandas/pull/43311#issuecomment-974891737 + and (values.shape[1] / 1000) > values.shape[0] and + values.dtype != object and values.dtype != bool): arrs = list(values) if kwargs.get("mask") is not None: mask = kwargs.pop("mask") results = [ - func(arrs[i], mask=mask[i], **kwargs) for i in range(len(arrs)) + func(arrs[i], mask=mask[i], **kwargs) + for i in range(len(arrs)) ] else: results = [func(x, **kwargs) for x in arrs] @@ -519,7 +523,10 @@ def nanany( >>> nanops.nanany(s) False """ - values, _, _, _, _ = _get_values(values, skipna, fill_value=False, mask=mask) + values, _, _, _, _ = _get_values(values, + skipna, + fill_value=False, + mask=mask) # For object type, any won't necessarily return # boolean values (numpy/numpy#4352) @@ -565,7 +572,10 @@ def nanall( >>> nanops.nanall(s) False """ - values, _, _, _, _ = _get_values(values, skipna, fill_value=True, mask=mask) + values, _, _, _, _ = _get_values(values, + skipna, + fill_value=True, + mask=mask) # For object type, all won't necessarily return # boolean values (numpy/numpy#4352) @@ -611,9 +621,10 @@ def nansum( >>> nanops.nansum(s) 3.0 """ - values, mask, dtype, dtype_max, _ = _get_values( - values, skipna, fill_value=0, mask=mask - ) + values, mask, dtype, dtype_max, _ = _get_values(values, + skipna, + fill_value=0, + mask=mask) dtype_sum = dtype_max if is_float_dtype(dtype): dtype_sum = dtype @@ -621,7 +632,11 @@ def nansum( dtype_sum = np.dtype(np.float64) the_sum = values.sum(axis, dtype=dtype_sum) - the_sum = _maybe_null_out(the_sum, axis, mask, values.shape, min_count=min_count) + the_sum = _maybe_null_out(the_sum, + axis, + mask, + values.shape, + min_count=min_count) return the_sum @@ -679,9 +694,10 @@ def nanmean( >>> nanops.nanmean(s) 1.5 """ - values, mask, dtype, dtype_max, _ = _get_values( - values, skipna, fill_value=0, mask=mask - ) + values, mask, dtype, dtype_max, _ = _get_values(values, + skipna, + fill_value=0, + mask=mask) dtype_sum = dtype_max dtype_count = np.dtype(np.float64) @@ -770,7 +786,8 @@ def get_median(x): # fastpath for the skipna case with warnings.catch_warnings(): # Suppress RuntimeWarning about All-NaN slice - warnings.filterwarnings("ignore", "All-NaN slice encountered") + warnings.filterwarnings("ignore", + "All-NaN slice encountered") res = np.nanmedian(values, axis) else: @@ -778,7 +795,8 @@ def get_median(x): # empty set so return nans of shape "everything but the passed axis" # since "axis" is where the reduction would occur if we had a nonempty # array - res = get_empty_reduction_result(values.shape, 
axis, np.float_, np.nan) + res = get_empty_reduction_result(values.shape, axis, np.float_, + np.nan) else: # otherwise return a scalar value @@ -895,7 +913,8 @@ def nanstd(values, *, axis=None, skipna=True, ddof=1, mask=None): orig_dtype = values.dtype values, mask, _, _, _ = _get_values(values, skipna, mask=mask) - result = np.sqrt(nanvar(values, axis=axis, skipna=skipna, ddof=ddof, mask=mask)) + result = np.sqrt( + nanvar(values, axis=axis, skipna=skipna, ddof=ddof, mask=mask)) return _wrap_results(result, orig_dtype) @@ -938,7 +957,8 @@ def nanvar(values, *, axis=None, skipna=True, ddof=1, mask=None): values[mask] = np.nan if is_float_dtype(values.dtype): - count, d = _get_counts_nanvar(values.shape, mask, axis, ddof, values.dtype) + count, d = _get_counts_nanvar(values.shape, mask, axis, ddof, + values.dtype) else: count, d = _get_counts_nanvar(values.shape, mask, axis, ddof) @@ -955,7 +975,7 @@ def nanvar(values, *, axis=None, skipna=True, ddof=1, mask=None): avg = _ensure_numeric(values.sum(axis=axis, dtype=np.float64)) / count if axis is not None: avg = np.expand_dims(avg, axis) - sqr = _ensure_numeric((avg - values) ** 2) + sqr = _ensure_numeric((avg - values)**2) if mask is not None: np.putmask(sqr, mask, 0) result = sqr.sum(axis=axis, dtype=np.float64) / d @@ -1019,6 +1039,7 @@ def nansem( def _nanminmax(meth, fill_value_typ): + @bottleneck_switch(name="nan" + meth) @_datetimelike_compat def reduction( @@ -1030,8 +1051,7 @@ def reduction( ) -> Dtype: values, mask, dtype, dtype_max, fill_value = _get_values( - values, skipna, fill_value_typ=fill_value_typ, mask=mask - ) + values, skipna, fill_value_typ=fill_value_typ, mask=mask) if (axis is not None and values.shape[axis] == 0) or values.size == 0: try: @@ -1091,7 +1111,10 @@ def nanargmax( >>> nanops.nanargmax(arr, axis=1) array([2, 2, 1, 1]) """ - values, mask, _, _, _ = _get_values(values, True, fill_value_typ="-inf", mask=mask) + values, mask, _, _, _ = _get_values(values, + True, + fill_value_typ="-inf", + mask=mask) # error: Need type annotation for 'result' result = values.argmax(axis) # type: ignore[var-annotated] result = _maybe_arg_null_out(result, axis, mask, skipna) @@ -1137,7 +1160,10 @@ def nanargmin( >>> nanops.nanargmin(arr, axis=1) array([0, 0, 1, 1]) """ - values, mask, _, _, _ = _get_values(values, True, fill_value_typ="+inf", mask=mask) + values, mask, _, _, _ = _get_values(values, + True, + fill_value_typ="+inf", + mask=mask) # error: Need type annotation for 'result' result = values.argmin(axis) # type: ignore[var-annotated] result = _maybe_arg_null_out(result, axis, mask, skipna) @@ -1183,7 +1209,8 @@ def nanskew( """ # error: Incompatible types in assignment (expression has type "Union[Any, # Union[ExtensionArray, ndarray]]", variable has type "ndarray") - values = extract_array(values, extract_numpy=True) # type: ignore[assignment] + values = extract_array(values, + extract_numpy=True) # type: ignore[assignment] mask = _maybe_get_mask(values, skipna, mask) if not is_float_dtype(values.dtype): values = values.astype("f8") @@ -1215,7 +1242,7 @@ def nanskew( m3 = _zero_out_fperr(m3) with np.errstate(invalid="ignore", divide="ignore"): - result = (count * (count - 1) ** 0.5 / (count - 2)) * (m3 / m2**1.5) + result = (count * (count - 1)**0.5 / (count - 2)) * (m3 / m2**1.5) dtype = values.dtype if is_float_dtype(dtype): @@ -1271,7 +1298,8 @@ def nankurt( """ # error: Incompatible types in assignment (expression has type "Union[Any, # Union[ExtensionArray, ndarray]]", variable has type "ndarray") - values = 
extract_array(values, extract_numpy=True) # type: ignore[assignment] + values = extract_array(values, + extract_numpy=True) # type: ignore[assignment] mask = _maybe_get_mask(values, skipna, mask) if not is_float_dtype(values.dtype): values = values.astype("f8") @@ -1296,7 +1324,7 @@ def nankurt( m4 = adjusted4.sum(axis, dtype=np.float64) with np.errstate(invalid="ignore", divide="ignore"): - adj = 3 * (count - 1) ** 2 / ((count - 2) * (count - 3)) + adj = 3 * (count - 1)**2 / ((count - 2) * (count - 3)) numerator = count * (count + 1) * (count - 1) * m4 denominator = (count - 2) * (count - 3) * m2**2 @@ -1370,8 +1398,11 @@ def nanprod( # error: Incompatible return value type (got "Union[ndarray, float]", expected # "float") return _maybe_null_out( # type: ignore[return-value] - result, axis, mask, values.shape, min_count=min_count - ) + result, + axis, + mask, + values.shape, + min_count=min_count) def _maybe_arg_null_out( @@ -1402,10 +1433,10 @@ def _maybe_arg_null_out( def _get_counts( - values_shape: Shape, - mask: npt.NDArray[np.bool_] | None, - axis: int | None, - dtype: np.dtype = np.dtype(np.float64), + values_shape: Shape, + mask: npt.NDArray[np.bool_] | None, + axis: int | None, + dtype: np.dtype = np.dtype(np.float64), ) -> int | float | np.ndarray: """ Get the count of non-null values along an axis @@ -1461,7 +1492,7 @@ def _maybe_null_out( else: # we have no nulls, kept mask=None in _maybe_get_mask below_count = shape[axis] - min_count < 0 - new_shape = shape[:axis] + shape[axis + 1 :] + new_shape = shape[:axis] + shape[axis + 1:] null_mask = np.broadcast_to(below_count, new_shape) if np.any(null_mask): @@ -1481,9 +1512,9 @@ def _maybe_null_out( return result -def check_below_min_count( - shape: tuple[int, ...], mask: npt.NDArray[np.bool_] | None, min_count: int -) -> bool: +def check_below_min_count(shape: tuple[int, ...], + mask: npt.NDArray[np.bool_] | None, + min_count: int) -> bool: """ Check for the `min_count` keyword. Returns True if below `min_count` (when missing value should be returned from the reduction). 
@@ -1522,9 +1553,11 @@ def _zero_out_fperr(arg): @disallow("M8", "m8") -def nancorr( - a: np.ndarray, b: np.ndarray, *, method="pearson", min_periods: int | None = None -): +def nancorr(a: np.ndarray, + b: np.ndarray, + *, + method="pearson", + min_periods: int | None = None): """ a, b: ndarrays """ @@ -1570,10 +1603,8 @@ def func(a, b): elif callable(method): return method - raise ValueError( - f"Unknown method '{method}', expected one of " - "'kendall', 'spearman', 'pearson', or callable" - ) + raise ValueError(f"Unknown method '{method}', expected one of " + "'kendall', 'spearman', 'pearson', or callable") @disallow("M8", "m8") @@ -1613,7 +1644,8 @@ def _ensure_numeric(x): x = x.astype(np.float64) except ValueError as err: # GH#29941 we get here with object arrays containing strs - raise TypeError(f"Could not convert {x} to numeric") from err + raise TypeError( + f"Could not convert {x} to numeric") from err else: if not np.any(np.imag(x)): x = x.real @@ -1634,6 +1666,7 @@ def _ensure_numeric(x): def make_nancomp(op): + def f(x, y): xmask = isna(x) ymask = isna(y) @@ -1712,7 +1745,7 @@ def na_accum_func(values: ArrayLike, accum_func, *, skipna: bool) -> ArrayLike: nz = (~np.asarray(mask)).nonzero()[0] if len(nz): # everything up to the first non-na entry stays NaT - result[: nz[0]] = iNaT + result[:nz[0]] = iNaT if isinstance(values.dtype, np.dtype): result = result.view(orig_dtype) @@ -1720,12 +1753,13 @@ def na_accum_func(values: ArrayLike, accum_func, *, skipna: bool) -> ArrayLike: # DatetimeArray/TimedeltaArray # TODO: have this case go through a DTA method? # For DatetimeTZDtype, view result as M8[ns] - npdtype = orig_dtype if isinstance(orig_dtype, np.dtype) else "M8[ns]" + npdtype = orig_dtype if isinstance(orig_dtype, + np.dtype) else "M8[ns]" # Item "type" of "Union[Type[ExtensionArray], Type[ndarray[Any, Any]]]" # has no attribute "_simple_new" result = type(values)._simple_new( # type: ignore[union-attr] - result.view(npdtype), dtype=orig_dtype - ) + result.view(npdtype), + dtype=orig_dtype) elif skipna and not issubclass(values.dtype.type, (np.integer, np.bool_)): vals = values.copy() diff --git a/pandas/core/ops/__init__.py b/pandas/core/ops/__init__.py index 79eb50b0368e7..28852b4341dc3 100644 --- a/pandas/core/ops/__init__.py +++ b/pandas/core/ops/__init__.py @@ -96,10 +96,8 @@ "rdivmod", } - COMPARISON_BINOPS: set[str] = {"eq", "ne", "lt", "gt", "le", "ge"} - # ----------------------------------------------------------------------------- # Masking NA values and fallbacks for operations numpy does not support @@ -204,9 +202,11 @@ def flex_wrapper(self, other, level=None, fill_value=None, axis=0): # DataFrame -def align_method_FRAME( - left, right, axis, flex: bool | None = False, level: Level = None -): +def align_method_FRAME(left, + right, + axis, + flex: bool | None = False, + level: Level = None): """ Convert rhs to meet lhs dims if input is list, tuple or np.ndarray. 
@@ -231,14 +231,13 @@ def to_series(right): if axis is not None and left._get_axis_name(axis) == "index": if len(left.index) != len(right): raise ValueError( - msg.format(req_len=len(left.index), given_len=len(right)) - ) + msg.format(req_len=len(left.index), given_len=len(right))) right = left._constructor_sliced(right, index=left.index) else: if len(left.columns) != len(right): raise ValueError( - msg.format(req_len=len(left.columns), given_len=len(right)) - ) + msg.format(req_len=len(left.columns), + given_len=len(right))) right = left._constructor_sliced(right, index=left.columns) return right @@ -249,30 +248,31 @@ def to_series(right): elif right.ndim == 2: if right.shape == left.shape: - right = left._constructor(right, index=left.index, columns=left.columns) + right = left._constructor(right, + index=left.index, + columns=left.columns) elif right.shape[0] == left.shape[0] and right.shape[1] == 1: # Broadcast across columns right = np.broadcast_to(right, left.shape) - right = left._constructor(right, index=left.index, columns=left.columns) + right = left._constructor(right, + index=left.index, + columns=left.columns) elif right.shape[1] == left.shape[1] and right.shape[0] == 1: # Broadcast along rows right = to_series(right[0, :]) else: - raise ValueError( - "Unable to coerce to DataFrame, shape " - f"must be {left.shape}: given {right.shape}" - ) + raise ValueError("Unable to coerce to DataFrame, shape " + f"must be {left.shape}: given {right.shape}") elif right.ndim > 2: - raise ValueError( - "Unable to coerce to Series/DataFrame, " - f"dimension must be <= 2: {right.shape}" - ) + raise ValueError("Unable to coerce to Series/DataFrame, " + f"dimension must be <= 2: {right.shape}") - elif is_list_like(right) and not isinstance(right, (ABCSeries, ABCDataFrame)): + elif is_list_like(right) and not isinstance(right, + (ABCSeries, ABCDataFrame)): # GH 36702. Raise when attempting arithmetic with list of array-like. if any(is_array_like(el) for el in right): raise ValueError( @@ -284,11 +284,13 @@ def to_series(right): if flex is not None and isinstance(right, ABCDataFrame): if not left._indexed_same(right): if flex: - left, right = left.align(right, join="outer", level=level, copy=False) + left, right = left.align(right, + join="outer", + level=level, + copy=False) else: raise ValueError( - "Can only compare identically-labeled DataFrame objects" - ) + "Can only compare identically-labeled DataFrame objects") elif isinstance(right, ABCSeries): # axis=1 is default for DataFrame-with-Series op axis = left._get_axis_number(axis) if axis is not None else 1 @@ -304,17 +306,18 @@ def to_series(right): stacklevel=find_stack_level(), ) - left, right = left.align( - right, join="outer", axis=axis, level=level, copy=False - ) + left, right = left.align(right, + join="outer", + axis=axis, + level=level, + copy=False) right = _maybe_align_series_as_frame(left, right, axis) return left, right -def should_reindex_frame_op( - left: DataFrame, right, op, axis, default_axis, fill_value, level -) -> bool: +def should_reindex_frame_op(left: DataFrame, right, op, axis, default_axis, + fill_value, level) -> bool: """ Check if this is an operation between DataFrames that will need to reindex. 
""" @@ -334,14 +337,16 @@ def should_reindex_frame_op( left_uniques = left.columns.unique() right_uniques = right.columns.unique() cols = left_uniques.intersection(right_uniques) - if len(cols) and not (cols.equals(left_uniques) and cols.equals(right_uniques)): + if len(cols) and not (cols.equals(left_uniques) + and cols.equals(right_uniques)): # TODO: is there a shortcut available when len(cols) == 0? return True return False -def frame_arith_method_with_reindex(left: DataFrame, right: DataFrame, op) -> DataFrame: +def frame_arith_method_with_reindex(left: DataFrame, right: DataFrame, + op) -> DataFrame: """ For DataFrame-with-DataFrame operations that require reindexing, operate only on shared columns, then reindex. @@ -357,9 +362,10 @@ def frame_arith_method_with_reindex(left: DataFrame, right: DataFrame, op) -> Da DataFrame """ # GH#31623, only operate on shared columns - cols, lcols, rcols = left.columns.join( - right.columns, how="inner", level=None, return_indexers=True - ) + cols, lcols, rcols = left.columns.join(right.columns, + how="inner", + level=None, + return_indexers=True) new_left = left.iloc[:, lcols] new_right = right.iloc[:, rcols] @@ -367,18 +373,18 @@ def frame_arith_method_with_reindex(left: DataFrame, right: DataFrame, op) -> Da # Do the join on the columns instead of using align_method_FRAME # to avoid constructing two potentially large/sparse DataFrames - join_columns, _, _ = left.columns.join( - right.columns, how="outer", level=None, return_indexers=True - ) + join_columns, _, _ = left.columns.join(right.columns, + how="outer", + level=None, + return_indexers=True) if result.columns.has_duplicates: # Avoid reindexing with a duplicate axis. # https://github.com/pandas-dev/pandas/issues/35194 indexer, _ = result.columns.get_indexer_non_unique(join_columns) indexer = algorithms.unique1d(indexer) - result = result._reindex_with_indexers( - {1: [join_columns, indexer]}, allow_dups=True - ) + result = result._reindex_with_indexers({1: [join_columns, indexer]}, + allow_dups=True) else: result = result.reindex(join_columns, axis=1) @@ -418,20 +424,24 @@ def flex_arith_method_FRAME(op): @Appender(doc) def f(self, other, axis=default_axis, level=None, fill_value=None): - if should_reindex_frame_op( - self, other, op, axis, default_axis, fill_value, level - ): + if should_reindex_frame_op(self, other, op, axis, default_axis, + fill_value, level): return frame_arith_method_with_reindex(self, other, op) if isinstance(other, ABCSeries) and fill_value is not None: # TODO: We could allow this in cases where we end up going # through the DataFrame path - raise NotImplementedError(f"fill_value {fill_value} not supported.") + raise NotImplementedError( + f"fill_value {fill_value} not supported.") axis = self._get_axis_number(axis) if axis is not None else 1 other = maybe_prepare_scalar_for_op(other, self.shape) - self, other = align_method_FRAME(self, other, axis, flex=True, level=level) + self, other = align_method_FRAME(self, + other, + axis, + flex=True, + level=level) if isinstance(other, ABCDataFrame): # Another DataFrame @@ -457,15 +467,18 @@ def flex_comp_method_FRAME(op): op_name = op.__name__.strip("_") default_axis = "columns" # because we are "flex" - doc = _flex_comp_doc_FRAME.format( - op_name=op_name, desc=_op_descriptions[op_name]["desc"] - ) + doc = _flex_comp_doc_FRAME.format(op_name=op_name, + desc=_op_descriptions[op_name]["desc"]) @Appender(doc) def f(self, other, axis=default_axis, level=None): axis = self._get_axis_number(axis) if axis is not None else 1 - 
self, other = align_method_FRAME(self, other, axis, flex=True, level=level) + self, other = align_method_FRAME(self, + other, + axis, + flex=True, + level=level) new_data = self._dispatch_frame_op(other, op, axis=axis) return self._construct_result(new_data) diff --git a/pandas/io/formats/excel.py b/pandas/io/formats/excel.py index 1792c81992354..0f853207cc79f 100644 --- a/pandas/io/formats/excel.py +++ b/pandas/io/formats/excel.py @@ -72,6 +72,7 @@ def __init__( class CssExcelCell(ExcelCell): + def __init__( self, row: int, @@ -86,11 +87,14 @@ def __init__( ): if css_styles and css_converter: css = ";".join( - [a + ":" + str(v) for (a, v) in css_styles[css_row, css_col]] - ) + [a + ":" + str(v) for (a, v) in css_styles[css_row, css_col]]) style = css_converter(css) - return super().__init__(row=row, col=col, val=val, style=style, **kwargs) + return super().__init__(row=row, + col=col, + val=val, + style=style, + **kwargs) class CSSToExcelConverter: @@ -186,7 +190,8 @@ def __call__(self, declarations_str: str) -> dict[str, dict[str, str]]: properties = self.compute_css(declarations_str, self.inherited) return self.build_xlstyle(properties) - def build_xlstyle(self, props: Mapping[str, str]) -> dict[str, dict[str, str]]: + def build_xlstyle(self, props: Mapping[str, + str]) -> dict[str, dict[str, str]]: out = { "alignment": self.build_alignment(props), "border": self.build_border(props), @@ -210,7 +215,8 @@ def remove_none(d: dict[str, str]) -> None: remove_none(out) return out - def build_alignment(self, props: Mapping[str, str]) -> dict[str, bool | str | None]: + def build_alignment( + self, props: Mapping[str, str]) -> dict[str, bool | str | None]: # TODO: text-indent, padding-left -> alignment.indent return { "horizontal": props.get("text-align"), @@ -230,21 +236,24 @@ def _get_is_wrap_text(self, props: Mapping[str, str]) -> bool | None: return bool(props["white-space"] not in ("nowrap", "pre", "pre-line")) def build_border( - self, props: Mapping[str, str] - ) -> dict[str, dict[str, str | None]]: + self, props: Mapping[str, + str]) -> dict[str, dict[str, str | None]]: return { side: { - "style": self._border_style( + "style": + self._border_style( props.get(f"border-{side}-style"), props.get(f"border-{side}-width"), self.color_to_excel(props.get(f"border-{side}-color")), ), - "color": self.color_to_excel(props.get(f"border-{side}-color")), + "color": + self.color_to_excel(props.get(f"border-{side}-color")), } for side in ["top", "right", "bottom", "left"] } - def _border_style(self, style: str | None, width: str | None, color: str | None): + def _border_style(self, style: str | None, width: str | None, + color: str | None): # convert styles and widths to openxml, one of: # 'dashDot' # 'dashDotDot' @@ -313,16 +322,21 @@ def build_fill(self, props: Mapping[str, str]): # -excel-pattern-bgcolor and -excel-pattern-type fill_color = props.get("background-color") if fill_color not in (None, "transparent", "none"): - return {"fgColor": self.color_to_excel(fill_color), "patternType": "solid"} + return { + "fgColor": self.color_to_excel(fill_color), + "patternType": "solid" + } - def build_number_format(self, props: Mapping[str, str]) -> dict[str, str | None]: + def build_number_format(self, + props: Mapping[str, str]) -> dict[str, str | None]: fc = props.get("number-format") fc = fc.replace("§", ";") if isinstance(fc, str) else fc return {"format_code": fc} def build_font( - self, props: Mapping[str, str] - ) -> dict[str, bool | int | float | str | None]: + self, + props: Mapping[str, + str]) 
-> dict[str, bool | int | float | str | None]: font_names = self._get_font_names(props) decoration = self._get_decoration(props) return { @@ -515,7 +529,8 @@ def __init__( if len(Index(cols).intersection(df.columns)) != len(set(cols)): # Deprecated in GH#17295, enforced in 1.0.0 - raise KeyError("Not all names specified in 'columns' are found") + raise KeyError( + "Not all names specified in 'columns' are found") self.df = df.reindex(columns=cols) @@ -530,14 +545,19 @@ def __init__( @property def header_style(self): return { - "font": {"bold": True}, + "font": { + "bold": True + }, "borders": { "top": "thin", "right": "thin", "bottom": "thin", "left": "thin", }, - "alignment": {"horizontal": "center", "vertical": "top"}, + "alignment": { + "horizontal": "center", + "vertical": "top" + }, } def _format_value(self, val): @@ -551,11 +571,9 @@ def _format_value(self, val): elif self.float_format is not None: val = float(self.float_format % val) if getattr(val, "tzinfo", None) is not None: - raise ValueError( - "Excel does not support datetimes with " - "timezones. Please ensure that datetimes " - "are timezone unaware before writing to Excel." - ) + raise ValueError("Excel does not support datetimes with " + "timezones. Please ensure that datetimes " + "are timezone unaware before writing to Excel.") return val def _format_header_mi(self) -> Iterable[ExcelCell]: @@ -563,16 +581,15 @@ def _format_header_mi(self) -> Iterable[ExcelCell]: if not self.index: raise NotImplementedError( "Writing to Excel with MultiIndex columns and no " - "index ('index'=False) is not yet implemented." - ) + "index ('index'=False) is not yet implemented.") if not (self._has_aliases or self.header): return columns = self.columns - level_strs = columns.format( - sparsify=self.merge_cells, adjoin=False, names=False - ) + level_strs = columns.format(sparsify=self.merge_cells, + adjoin=False, + names=False) level_lengths = get_level_lengths(level_strs) coloffset = 0 lnum = 0 @@ -591,8 +608,7 @@ def _format_header_mi(self) -> Iterable[ExcelCell]: ) for lnum, (spans, levels, level_codes) in enumerate( - zip(level_lengths, columns.levels, columns.codes) - ): + zip(level_lengths, columns.levels, columns.codes)): values = levels.take(level_codes) for i, span_val in spans.items(): mergestart, mergeend = None, None @@ -640,10 +656,8 @@ def _format_header_regular(self) -> Iterable[ExcelCell]: if self._has_aliases: self.header = cast(Sequence, self.header) if len(self.header) != len(self.columns): - raise ValueError( - f"Writing {len(self.columns)} cols " - f"but got {len(self.header)} aliases" - ) + raise ValueError(f"Writing {len(self.columns)} cols " + f"but got {len(self.header)} aliases") else: colnames = self.header @@ -670,14 +684,12 @@ def _format_header(self) -> Iterable[ExcelCell]: gen2: Iterable[ExcelCell] = () if self.df.index.names: - row = [x if x is not None else "" for x in self.df.index.names] + [ - "" - ] * len(self.columns) + row = [x if x is not None else "" + for x in self.df.index.names] + [""] * len(self.columns) if reduce(lambda x, y: x and y, map(lambda x: x != "", row)): - gen2 = ( - ExcelCell(self.rowcounter, colindex, val, self.header_style) - for colindex, val in enumerate(row) - ) + gen2 = (ExcelCell(self.rowcounter, colindex, val, + self.header_style) + for colindex, val in enumerate(row)) self.rowcounter += 1 return itertools.chain(gen, gen2) @@ -696,8 +708,7 @@ def _format_regular_rows(self) -> Iterable[ExcelCell]: # check aliases # if list only take first as this is not a MultiIndex if 
self.index_label and isinstance( - self.index_label, (list, tuple, np.ndarray, Index) - ): + self.index_label, (list, tuple, np.ndarray, Index)): index_label = self.index_label[0] # if string good to go elif self.index_label and isinstance(self.index_label, str): @@ -709,7 +720,8 @@ def _format_regular_rows(self) -> Iterable[ExcelCell]: self.rowcounter += 1 if index_label and self.header is not False: - yield ExcelCell(self.rowcounter - 1, 0, index_label, self.header_style) + yield ExcelCell(self.rowcounter - 1, 0, index_label, + self.header_style) # write index_values index_values = self.df.index @@ -743,8 +755,7 @@ def _format_hierarchical_rows(self) -> Iterable[ExcelCell]: index_labels = self.df.index.names # check for aliases if self.index_label and isinstance( - self.index_label, (list, tuple, np.ndarray, Index) - ): + self.index_label, (list, tuple, np.ndarray, Index)): index_labels = self.index_label # MultiIndex columns require an extra row @@ -758,18 +769,19 @@ def _format_hierarchical_rows(self) -> Iterable[ExcelCell]: if com.any_not_none(*index_labels) and self.header is not False: for cidx, name in enumerate(index_labels): - yield ExcelCell(self.rowcounter - 1, cidx, name, self.header_style) + yield ExcelCell(self.rowcounter - 1, cidx, name, + self.header_style) if self.merge_cells: # Format hierarchical rows as merged cells. - level_strs = self.df.index.format( - sparsify=True, adjoin=False, names=False - ) + level_strs = self.df.index.format(sparsify=True, + adjoin=False, + names=False) level_lengths = get_level_lengths(level_strs) - for spans, levels, level_codes in zip( - level_lengths, self.df.index.levels, self.df.index.codes - ): + for spans, levels, level_codes in zip(level_lengths, + self.df.index.levels, + self.df.index.codes): values = levels.take( level_codes, @@ -836,7 +848,8 @@ def _generate_body(self, coloffset: int) -> Iterable[ExcelCell]: ) def get_formatted_cells(self) -> Iterable[ExcelCell]: - for cell in itertools.chain(self._format_header(), self._format_body()): + for cell in itertools.chain(self._format_header(), + self._format_body()): cell.val = self._format_value(cell.val) yield cell @@ -884,8 +897,7 @@ def write( if num_rows > self.max_rows or num_cols > self.max_cols: raise ValueError( f"This sheet is too large! 
Your sheet size is: {num_rows}, {num_cols} " - f"Max sheet size is: {self.max_rows}, {self.max_cols}" - ) + f"Max sheet size is: {self.max_rows}, {self.max_cols}") formatted_cells = self.get_formatted_cells() if isinstance(writer, ExcelWriter): @@ -894,8 +906,9 @@ def write( # error: Cannot instantiate abstract class 'ExcelWriter' with abstract # attributes 'engine', 'save', 'supported_extensions' and 'write_cells' writer = ExcelWriter( # type: ignore[abstract] - writer, engine=engine, storage_options=storage_options - ) + writer, + engine=engine, + storage_options=storage_options) need_save = True try: diff --git a/pandas/io/formats/style.py b/pandas/io/formats/style.py index db02b0fd39e21..092812fdf3bdd 100644 --- a/pandas/io/formats/style.py +++ b/pandas/io/formats/style.py @@ -50,7 +50,8 @@ from pandas.io.formats.format import save_to_buffer -jinja2 = import_optional_dependency("jinja2", extra="DataFrame.style requires jinja2.") +jinja2 = import_optional_dependency("jinja2", + extra="DataFrame.style requires jinja2.") from pandas.io.formats.style_render import ( CSSProperties, @@ -358,12 +359,12 @@ def concat(self, other: Styler) -> Styler: if not isinstance(other, Styler): raise TypeError("`other` must be of type `Styler`") if not self.data.columns.equals(other.data.columns): - raise ValueError("`other.data` must have same columns as `Styler.data`") + raise ValueError( + "`other.data` must have same columns as `Styler.data`") if not self.data.index.nlevels == other.data.index.nlevels: raise ValueError( "number of index levels must be same in `other` " - "as in `Styler`. See documentation for suggestions." - ) + "as in `Styler`. See documentation for suggestions.") self.concatenated = other return self @@ -527,8 +528,7 @@ def set_tooltips( # tooltips not optimised for individual cell check. requires reasonable # redesign and more extensive code for a feature that might be rarely used. raise NotImplementedError( - "Tooltips can only render with 'cell_ids' is True." - ) + "Tooltips can only render with 'cell_ids' is True.") if not ttips.index.is_unique or not ttips.columns.is_unique: raise KeyError( "Tooltips render only if `ttips` has unique index and columns." @@ -1058,18 +1058,19 @@ def to_latex( .. 
figure:: ../../_static/style/latex_stocks.png """ - obj = self._copy(deepcopy=True) # manipulate table_styles on obj, not self + obj = self._copy( + deepcopy=True) # manipulate table_styles on obj, not self - table_selectors = ( - [style["selector"] for style in self.table_styles] - if self.table_styles is not None - else [] - ) + table_selectors = ([style["selector"] for style in self.table_styles] + if self.table_styles is not None else []) if column_format is not None: # add more recent setting to table_styles obj.set_table_styles( - [{"selector": "column_format", "props": f":{column_format}"}], + [{ + "selector": "column_format", + "props": f":{column_format}" + }], overwrite=False, ) elif "column_format" in table_selectors: @@ -1085,17 +1086,22 @@ def to_latex( column_format += "" if self.hide_index_[level] else "l" for ci, _ in enumerate(self.data.columns): if ci not in self.hidden_columns: - column_format += ( - ("r" if not siunitx else "S") if ci in numeric_cols else "l" - ) + column_format += (("r" if not siunitx else "S") + if ci in numeric_cols else "l") obj.set_table_styles( - [{"selector": "column_format", "props": f":{column_format}"}], + [{ + "selector": "column_format", + "props": f":{column_format}" + }], overwrite=False, ) if position: obj.set_table_styles( - [{"selector": "position", "props": f":{position}"}], + [{ + "selector": "position", + "props": f":{position}" + }], overwrite=False, ) @@ -1104,31 +1110,47 @@ def to_latex( raise ValueError( "`position_float` cannot be used in 'longtable' `environment`" ) - if position_float not in ["raggedright", "raggedleft", "centering"]: - raise ValueError( - f"`position_float` should be one of " - f"'raggedright', 'raggedleft', 'centering', " - f"got: '{position_float}'" - ) + if position_float not in [ + "raggedright", "raggedleft", "centering" + ]: + raise ValueError(f"`position_float` should be one of " + f"'raggedright', 'raggedleft', 'centering', " + f"got: '{position_float}'") obj.set_table_styles( - [{"selector": "position_float", "props": f":{position_float}"}], + [{ + "selector": "position_float", + "props": f":{position_float}" + }], overwrite=False, ) - hrules = get_option("styler.latex.hrules") if hrules is None else hrules + hrules = get_option( + "styler.latex.hrules") if hrules is None else hrules if hrules: obj.set_table_styles( [ - {"selector": "toprule", "props": ":toprule"}, - {"selector": "midrule", "props": ":midrule"}, - {"selector": "bottomrule", "props": ":bottomrule"}, + { + "selector": "toprule", + "props": ":toprule" + }, + { + "selector": "midrule", + "props": ":midrule" + }, + { + "selector": "bottomrule", + "props": ":bottomrule" + }, ], overwrite=False, ) if label: obj.set_table_styles( - [{"selector": "label", "props": f":{{{label.replace(':', '§')}}}"}], + [{ + "selector": "label", + "props": f":{{{label.replace(':', '§')}}}" + }], overwrite=False, ) @@ -1140,8 +1162,10 @@ def to_latex( if sparse_columns is None: sparse_columns = get_option("styler.sparse.columns") environment = environment or get_option("styler.latex.environment") - multicol_align = multicol_align or get_option("styler.latex.multicol_align") - multirow_align = multirow_align or get_option("styler.latex.multirow_align") + multicol_align = multicol_align or get_option( + "styler.latex.multicol_align") + multirow_align = multirow_align or get_option( + "styler.latex.multirow_align") latex = obj._render_latex( sparse_index=sparse_index, sparse_columns=sparse_columns, @@ -1154,9 +1178,9 @@ def to_latex( ) encoding = encoding or 
get_option("styler.render.encoding") - return save_to_buffer( - latex, buf=buf, encoding=None if buf is None else encoding - ) + return save_to_buffer(latex, + buf=buf, + encoding=None if buf is None else encoding) @Substitution(buf=buf, encoding=encoding) def to_html( @@ -1252,7 +1276,8 @@ def to_html( -------- DataFrame.to_html: Write a DataFrame to a file, buffer or string in HTML format. """ - obj = self._copy(deepcopy=True) # manipulate table_styles on obj, not self + obj = self._copy( + deepcopy=True) # manipulate table_styles on obj, not self if table_uuid: obj.set_uuid(table_uuid) @@ -1266,9 +1291,11 @@ def to_html( sparse_columns = get_option("styler.sparse.columns") if bold_headers: - obj.set_table_styles( - [{"selector": "th", "props": "font-weight: bold;"}], overwrite=False - ) + obj.set_table_styles([{ + "selector": "th", + "props": "font-weight: bold;" + }], + overwrite=False) if caption is not None: obj.set_caption(caption) @@ -1286,9 +1313,9 @@ def to_html( **kwargs, ) - return save_to_buffer( - html, buf=buf, encoding=(encoding if buf is not None else None) - ) + return save_to_buffer(html, + buf=buf, + encoding=(encoding if buf is not None else None)) @Substitution(buf=buf, encoding=encoding) def to_string( @@ -1351,9 +1378,9 @@ def to_string( max_cols=max_columns, delimiter=delimiter, ) - return save_to_buffer( - text, buf=buf, encoding=(encoding if buf is not None else None) - ) + return save_to_buffer(text, + buf=buf, + encoding=(encoding if buf is not None else None)) def set_td_classes(self, classes: DataFrame) -> Styler: """ @@ -1445,10 +1472,8 @@ def _update_ctx(self, attrs: DataFrame) -> None: matter. """ if not self.index.is_unique or not self.columns.is_unique: - raise KeyError( - "`Styler.apply` and `.applymap` are not compatible " - "with non-unique index or columns." - ) + raise KeyError("`Styler.apply` and `.applymap` are not compatible " + "with non-unique index or columns.") for cn in attrs.columns: j = self.columns.get_loc(cn) @@ -1513,7 +1538,8 @@ def _copy(self, deepcopy: bool = False) -> Styler: """ # GH 40675 styler = Styler( - self.data, # populates attributes 'data', 'columns', 'index' as shallow + self. 
+ data, # populates attributes 'data', 'columns', 'index' as shallow ) shallow = [ # simple string or boolean immutables "hide_index_", @@ -1571,7 +1597,8 @@ def clear(self) -> None: # create default GH 40675 clean_copy = Styler(self.data, uuid=self.uuid) clean_attrs = [a for a in clean_copy.__dict__ if not callable(a)] - self_attrs = [a for a in self.__dict__ if not callable(a)] # maybe more attrs + self_attrs = [a for a in self.__dict__ + if not callable(a)] # maybe more attrs for attr in clean_attrs: setattr(self, attr, getattr(clean_copy, attr)) for attr in set(self_attrs).difference(clean_attrs): @@ -1595,15 +1622,15 @@ def _apply( if not isinstance(result, np.ndarray): raise TypeError( f"Function {repr(func)} must return a DataFrame or ndarray " - f"when passed to `Styler.apply` with axis=None" - ) + f"when passed to `Styler.apply` with axis=None") if not (data.shape == result.shape): raise ValueError( f"Function {repr(func)} returned ndarray with wrong shape.\n" f"Result has shape: {result.shape}\n" - f"Expected shape: {data.shape}" - ) - result = DataFrame(result, index=data.index, columns=data.columns) + f"Expected shape: {data.shape}") + result = DataFrame(result, + index=data.index, + columns=data.columns) else: axis = self.data._get_axis_number(axis) if axis == 0: @@ -1615,8 +1642,7 @@ def _apply( raise ValueError( f"Function {repr(func)} resulted in the apply method collapsing to a " f"Series.\nUsually, this is the result of the function returning a " - f"single value, instead of list-like." - ) + f"single value, instead of list-like.") msg = ( f"Function {repr(func)} created invalid {{0}} labels.\nUsually, this is " f"the result of the function returning a " @@ -1625,14 +1651,14 @@ def _apply( f"cannot be mapped to labels, possibly due to applying the function along " f"the wrong axis.\n" f"Result {{0}} has shape: {{1}}\n" - f"Expected {{0}} shape: {{2}}" - ) + f"Expected {{0}} shape: {{2}}") if not all(result.index.isin(data.index)): - raise ValueError(msg.format("index", result.index.shape, data.index.shape)) + raise ValueError( + msg.format("index", result.index.shape, data.index.shape)) if not all(result.columns.isin(data.columns)): raise ValueError( - msg.format("columns", result.columns.shape, data.columns.shape) - ) + msg.format("columns", result.columns.shape, + data.columns.shape)) self._update_ctx(result) return self @@ -1724,9 +1750,8 @@ def apply( See `Table Visualization <../../user_guide/style.ipynb>`_ user guide for more details. """ - self._todo.append( - (lambda instance: getattr(instance, "_apply"), (func, axis, subset), kwargs) - ) + self._todo.append((lambda instance: getattr(instance, "_apply"), + (func, axis, subset), kwargs)) return self def _apply_index( @@ -1826,13 +1851,11 @@ def apply_index( .. 
figure:: ../../_static/style/appmaphead2.png """ - self._todo.append( - ( - lambda instance: getattr(instance, "_apply_index"), - (func, axis, level, "apply"), - kwargs, - ) - ) + self._todo.append(( + lambda instance: getattr(instance, "_apply_index"), + (func, axis, level, "apply"), + kwargs, + )) return self @doc( @@ -1855,18 +1878,17 @@ def applymap_index( level: Level | list[Level] | None = None, **kwargs, ) -> Styler: - self._todo.append( - ( - lambda instance: getattr(instance, "_apply_index"), - (func, axis, level, "applymap"), - kwargs, - ) - ) + self._todo.append(( + lambda instance: getattr(instance, "_apply_index"), + (func, axis, level, "applymap"), + kwargs, + )) return self - def _applymap( - self, func: Callable, subset: Subset | None = None, **kwargs - ) -> Styler: + def _applymap(self, + func: Callable, + subset: Subset | None = None, + **kwargs) -> Styler: func = partial(func, **kwargs) # applymap doesn't take kwargs? if subset is None: subset = IndexSlice[:] @@ -1876,9 +1898,10 @@ def _applymap( return self @Substitution(subset=subset) - def applymap( - self, func: Callable, subset: Subset | None = None, **kwargs - ) -> Styler: + def applymap(self, + func: Callable, + subset: Subset | None = None, + **kwargs) -> Styler: """ Apply a CSS-styling function elementwise. @@ -1932,9 +1955,8 @@ def applymap( See `Table Visualization <../../user_guide/style.ipynb>`_ user guide for more details. """ - self._todo.append( - (lambda instance: getattr(instance, "_applymap"), (func, subset), kwargs) - ) + self._todo.append((lambda instance: getattr(instance, "_applymap"), + (func, subset), kwargs)) return self @Substitution(subset=subset) @@ -2170,12 +2192,10 @@ def use(self, styles: dict[str, Any]) -> Styler: """ self._todo.extend(styles.get("apply", [])) table_attributes: str = self.table_attributes or "" - obj_table_atts: str = ( - "" - if styles.get("table_attributes") is None - else str(styles.get("table_attributes")) - ) - self.set_table_attributes((table_attributes + " " + obj_table_atts).strip()) + obj_table_atts: str = ("" if styles.get("table_attributes") is None + else str(styles.get("table_attributes"))) + self.set_table_attributes( + (table_attributes + " " + obj_table_atts).strip()) if styles.get("table_styles"): self.set_table_styles(styles.get("table_styles"), overwrite=False) @@ -2232,11 +2252,8 @@ def set_caption(self, caption: str | tuple) -> Styler: """ msg = "`caption` must be either a string or 2-tuple of strings." 
if isinstance(caption, tuple): - if ( - len(caption) != 2 - or not isinstance(caption[0], str) - or not isinstance(caption[1], str) - ): + if (len(caption) != 2 or not isinstance(caption[0], str) + or not isinstance(caption[1], str)): raise ValueError(msg) elif not isinstance(caption, str): raise ValueError(msg) @@ -2280,7 +2297,8 @@ def set_sticky( """ axis = self.data._get_axis_number(axis) obj = self.data.index if axis == 0 else self.data.columns - pixel_size = (75 if axis == 0 else 25) if not pixel_size else pixel_size + pixel_size = ( + 75 if axis == 0 else 25) if not pixel_size else pixel_size props = "position:sticky; background-color:white;" if not isinstance(obj, pd.MultiIndex): @@ -2289,23 +2307,20 @@ def set_sticky( if axis == 1: # stick the first of and, if index names, the second # if self._hide_columns then no here will exist: no conflict - styles: CSSStyles = [ - { - "selector": "thead tr:nth-child(1) th", - "props": props + "top:0px; z-index:2;", - } - ] + styles: CSSStyles = [{ + "selector": "thead tr:nth-child(1) th", + "props": props + "top:0px; z-index:2;", + }] if not self.index.names[0] is None: styles[0]["props"] = ( - props + f"top:0px; z-index:2; height:{pixel_size}px;" - ) - styles.append( - { - "selector": "thead tr:nth-child(2) th", - "props": props - + f"top:{pixel_size}px; z-index:2; height:{pixel_size}px; ", - } - ) + props + f"top:0px; z-index:2; height:{pixel_size}px;") + styles.append({ + "selector": + "thead tr:nth-child(2) th", + "props": + props + + f"top:{pixel_size}px; z-index:2; height:{pixel_size}px; ", + }) else: # stick the first of each in both and # if self._hide_index then no will exist in : no conflict @@ -2324,54 +2339,47 @@ def set_sticky( else: # handle the MultiIndex case range_idx = list(range(obj.nlevels)) - levels_: list[int] = refactor_levels(levels, obj) if levels else range_idx + levels_: list[int] = refactor_levels(levels, + obj) if levels else range_idx levels_ = sorted(levels_) if axis == 1: styles = [] for i, level in enumerate(levels_): - styles.append( - { - "selector": f"thead tr:nth-child({level+1}) th", - "props": props - + ( - f"top:{i * pixel_size}px; height:{pixel_size}px; " - "z-index:2;" - ), - } - ) + styles.append({ + "selector": + f"thead tr:nth-child({level+1}) th", + "props": + props + + (f"top:{i * pixel_size}px; height:{pixel_size}px; " + "z-index:2;"), + }) if not all(name is None for name in self.index.names): - styles.append( - { - "selector": f"thead tr:nth-child({obj.nlevels+1}) th", - "props": props - + ( - f"top:{(i+1) * pixel_size}px; height:{pixel_size}px; " - "z-index:2;" - ), - } - ) + styles.append({ + "selector": + f"thead tr:nth-child({obj.nlevels+1}) th", + "props": + props + + (f"top:{(i+1) * pixel_size}px; height:{pixel_size}px; " + "z-index:2;"), + }) else: styles = [] for i, level in enumerate(levels_): - props_ = props + ( - f"left:{i * pixel_size}px; " - f"min-width:{pixel_size}px; " - f"max-width:{pixel_size}px; " - ) - styles.extend( - [ - { - "selector": f"thead tr th:nth-child({level+1})", - "props": props_ + "z-index:3 !important;", - }, - { - "selector": f"tbody tr th.level{level}", - "props": props_ + "z-index:1;", - }, - ] - ) + props_ = props + (f"left:{i * pixel_size}px; " + f"min-width:{pixel_size}px; " + f"max-width:{pixel_size}px; ") + styles.extend([ + { + "selector": f"thead tr th:nth-child({level+1})", + "props": props_ + "z-index:3 !important;", + }, + { + "selector": f"tbody tr th.level{level}", + "props": props_ + "z-index:1;", + }, + ]) return self.set_table_styles(styles, 
overwrite=False) @@ -2499,23 +2507,17 @@ def set_table_styles( obj = self.data.index if axis == 1 else self.data.columns idf = f".{self.css['row']}" if axis == 1 else f".{self.css['col']}" - table_styles = [ - { - "selector": str(s["selector"]) + idf + str(idx), - "props": maybe_convert_css_to_tuples(s["props"]), - } - for key, styles in table_styles.items() + table_styles = [{ + "selector": str(s["selector"]) + idf + str(idx), + "props": maybe_convert_css_to_tuples(s["props"]), + } for key, styles in table_styles.items() for idx in obj.get_indexer_for([key]) - for s in format_table_styles(styles) - ] + for s in format_table_styles(styles)] else: - table_styles = [ - { - "selector": s["selector"], - "props": maybe_convert_css_to_tuples(s["props"]), - } - for s in table_styles - ] + table_styles = [{ + "selector": s["selector"], + "props": maybe_convert_css_to_tuples(s["props"]), + } for s in table_styles] if not overwrite and self.table_styles is not None: self.table_styles.extend(table_styles) @@ -2820,7 +2822,8 @@ def hide( obj, objs, alt = "column", "columns", "columns" if level is not None and subset is not None: - raise ValueError("`subset` and `level` cannot be passed simultaneously") + raise ValueError( + "`subset` and `level` cannot be passed simultaneously") if subset is None: if level is None and names: @@ -2839,9 +2842,11 @@ def hide( ) else: if axis == 0: - subset_ = IndexSlice[subset, :] # new var so mypy reads not Optional + subset_ = IndexSlice[ + subset, :] # new var so mypy reads not Optional else: - subset_ = IndexSlice[:, subset] # new var so mypy reads not Optional + subset_ = IndexSlice[:, + subset] # new var so mypy reads not Optional subset = non_reducing_slice(subset_) hide = self.data.loc[subset] h_els = getattr(self, objs).get_indexer_for(getattr(hide, objs)) @@ -3172,18 +3177,19 @@ def bar( elif color is not None and cmap is not None: raise ValueError("`color` and `cmap` cannot both be given") elif color is not None: - if (isinstance(color, (list, tuple)) and len(color) > 2) or not isinstance( - color, (str, list, tuple) - ): + if (isinstance(color, (list, tuple)) and + len(color) > 2) or not isinstance(color, + (str, list, tuple)): raise ValueError( "`color` must be string or list or tuple of 2 strings," - "(eg: color=['#d65f5f', '#5fba7d'])" - ) + "(eg: color=['#d65f5f', '#5fba7d'])") if not (0 <= width <= 100): - raise ValueError(f"`width` must be a value in [0, 100], got {width}") + raise ValueError( + f"`width` must be a value in [0, 100], got {width}") elif not (0 <= height <= 100): - raise ValueError(f"`height` must be a value in [0, 100], got {height}") + raise ValueError( + f"`height` must be a value in [0, 100], got {height}") if subset is None: subset = self.data.select_dtypes(include=np.number).columns @@ -3560,9 +3566,10 @@ def highlight_quantile( ) @classmethod - def from_custom_template( - cls, searchpath, html_table: str | None = None, html_style: str | None = None - ): + def from_custom_template(cls, + searchpath, + html_table: str | None = None, + html_style: str | None = None): """ Factory function for creating a subclass of ``Styler``. @@ -3590,7 +3597,8 @@ def from_custom_template( Has the correct ``env``,``template_html``, ``template_html_table`` and ``template_html_style`` class attributes set. 
""" - loader = jinja2.ChoiceLoader([jinja2.FileSystemLoader(searchpath), cls.loader]) + loader = jinja2.ChoiceLoader( + [jinja2.FileSystemLoader(searchpath), cls.loader]) # mypy doesn't like dynamically-defined classes # error: Variable "cls" is not valid as a type @@ -3753,13 +3761,11 @@ def _validate_apply_axis_arg( if isinstance(arg, Series) and isinstance(data, DataFrame): raise ValueError( f"'{arg_name}' is a Series but underlying data for operations " - f"is a DataFrame since 'axis=None'" - ) + f"is a DataFrame since 'axis=None'") elif isinstance(arg, DataFrame) and isinstance(data, Series): raise ValueError( f"'{arg_name}' is a DataFrame but underlying data for " - f"operations is a Series with 'axis in [0,1]'" - ) + f"operations is a Series with 'axis in [0,1]'") elif isinstance(arg, (Series, DataFrame)): # align indx / cols to data arg = arg.reindex_like(data, method=None).to_numpy(**dtype) else: @@ -3769,8 +3775,7 @@ def _validate_apply_axis_arg( raise ValueError( f"supplied '{arg_name}' is not correct shape for data over " f"selected 'axis': got {arg.shape}, " - f"expected {data.shape}" - ) + f"expected {data.shape}") return arg @@ -3817,10 +3822,8 @@ def relative_luminance(rgba) -> float: float The relative luminance as a value from 0 to 1 """ - r, g, b = ( - x / 12.92 if x <= 0.04045 else ((x + 0.055) / 1.055) ** 2.4 - for x in rgba[:3] - ) + r, g, b = (x / 12.92 if x <= 0.04045 else + ((x + 0.055) / 1.055)**2.4 for x in rgba[:3]) return 0.2126 * r + 0.7152 * g + 0.0722 * b def css(rgba, text_only) -> str: @@ -3871,32 +3874,27 @@ def _highlight_between( else: raise ValueError( f"'inclusive' values can be 'both', 'left', 'right', or 'neither' " - f"got {inclusive}" - ) + f"got {inclusive}") - g_left = ( - ops[0](data, left) - if left is not None - else np.full(data.shape, True, dtype=bool) - ) + g_left = (ops[0](data, left) + if left is not None else np.full(data.shape, True, dtype=bool)) if isinstance(g_left, (DataFrame, Series)): g_left = g_left.where(pd.notna(g_left), False) - l_right = ( - ops[1](data, right) - if right is not None - else np.full(data.shape, True, dtype=bool) - ) + l_right = (ops[1](data, right) + if right is not None else np.full(data.shape, True, dtype=bool)) if isinstance(l_right, (DataFrame, Series)): l_right = l_right.where(pd.notna(l_right), False) return np.where(g_left & l_right, props, "") -def _highlight_value(data: DataFrame | Series, op: str, props: str) -> np.ndarray: +def _highlight_value(data: DataFrame | Series, op: str, + props: str) -> np.ndarray: """ Return an array of css strings based on the condition of values matching an op. """ value = getattr(data, op)(skipna=True) - if isinstance(data, DataFrame): # min/max must be done twice to return scalar + if isinstance(data, + DataFrame): # min/max must be done twice to return scalar value = getattr(value, op)(skipna=True) cond = data == value cond = cond.where(pd.notna(cond), False) @@ -3969,7 +3967,8 @@ def css_bar(start: float, end: float, color: str) -> str: cell_css += f" {color} {end*100:.1f}%, transparent {end*100:.1f}%)" return cell_css - def css_calc(x, left: float, right: float, align: str, color: str | list | tuple): + def css_calc(x, left: float, right: float, align: str, + color: str | list | tuple): """ Return the correct CSS for bar placement based on calculated values. 
@@ -4027,9 +4026,8 @@ def css_calc(x, left: float, right: float, align: str, color: str | list | tuple elif align == "mid": # bars drawn from zero either leftwards or rightwards with center at mid mid: float = (left + right) / 2 - z_frac = ( - -mid / (right - left) + 0.5 if mid < 0 else -left / (right - left) - ) + z_frac = (-mid / (right - left) + 0.5 if mid < 0 else -left / + (right - left)) if x < 0: start, end = (x - left) / (right - left), z_frac @@ -4039,7 +4037,8 @@ def css_calc(x, left: float, right: float, align: str, color: str | list | tuple ret = css_bar(start * width, end * width, color) if height < 1 and "background: linear-gradient(" in ret: return ( - ret + f" no-repeat center; background-size: 100% {height * 100:.1f}%;" + ret + + f" no-repeat center; background-size: 100% {height * 100:.1f}%;" ) else: return ret @@ -4071,38 +4070,32 @@ def css_calc(x, left: float, right: float, align: str, color: str | list | tuple # use the matplotlib colormap input with _mpl(Styler.bar) as (plt, mpl): cmap = ( - mpl.cm.get_cmap(cmap) - if isinstance(cmap, str) - else cmap # assumed to be a Colormap instance as documented + mpl.cm.get_cmap(cmap) if isinstance(cmap, str) else + cmap # assumed to be a Colormap instance as documented ) norm = mpl.colors.Normalize(left, right) rgbas = cmap(norm(values)) if data.ndim == 1: rgbas = [mpl.colors.rgb2hex(rgba) for rgba in rgbas] else: - rgbas = [[mpl.colors.rgb2hex(rgba) for rgba in row] for row in rgbas] + rgbas = [[mpl.colors.rgb2hex(rgba) for rgba in row] + for row in rgbas] - assert isinstance(align, str) # mypy: should now be in [left, right, mid, zero] + assert isinstance(align, + str) # mypy: should now be in [left, right, mid, zero] if data.ndim == 1: return [ - css_calc( - x - z, left - z, right - z, align, colors if rgbas is None else rgbas[i] - ) + css_calc(x - z, left - z, right - z, align, + colors if rgbas is None else rgbas[i]) for i, x in enumerate(values) ] else: - return np.array( - [ - [ - css_calc( - x - z, - left - z, - right - z, - align, - colors if rgbas is None else rgbas[i][j], - ) - for j, x in enumerate(row) - ] - for i, row in enumerate(values) - ] - ) + return np.array([[ + css_calc( + x - z, + left - z, + right - z, + align, + colors if rgbas is None else rgbas[i][j], + ) for j, x in enumerate(row) + ] for i, row in enumerate(values)]) diff --git a/pandas/io/formats/style_render.py b/pandas/io/formats/style_render.py index 8dc75acb24330..9fbe7dcfdae66 100644 --- a/pandas/io/formats/style_render.py +++ b/pandas/io/formats/style_render.py @@ -43,7 +43,8 @@ from pandas.api.types import is_list_like import pandas.core.common as com -jinja2 = import_optional_dependency("jinja2", extra="DataFrame.style requires jinja2.") +jinja2 = import_optional_dependency("jinja2", + extra="DataFrame.style requires jinja2.") from markupsafe import escape as escape_html # markupsafe is jinja2 dependency BaseFormatter = Union[str, Callable] @@ -73,7 +74,8 @@ class StylerRenderer: Base class to process rendering a Styler with a specified jinja2 template. 
""" - loader = _gl01_adjust(jinja2.PackageLoader("pandas", "io/formats/templates")) + loader = _gl01_adjust( + jinja2.PackageLoader("pandas", "io/formats/templates")) env = _gl01_adjust(jinja2.Environment(loader=loader, trim_blocks=True)) template_html = _gl01_adjust(env.get_template("html.tpl")) template_html_table = _gl01_adjust(env.get_template("html_table.tpl")) @@ -102,8 +104,9 @@ def __init__( self.index: Index = data.index self.columns: Index = data.columns if not isinstance(uuid_len, int) or not uuid_len >= 0: - raise TypeError("``uuid_len`` must be an integer in range [0, 32].") - self.uuid = uuid or uuid4().hex[: min(32, uuid_len)] + raise TypeError( + "``uuid_len`` must be an integer in range [0, 32].") + self.uuid = uuid or uuid4().hex[:min(32, uuid_len)] self.uuid_len = len(self.uuid) self.table_styles = table_styles self.table_attributes = table_attributes @@ -128,26 +131,28 @@ def __init__( self.hide_column_names: bool = False self.hide_index_: list = [False] * self.index.nlevels self.hide_columns_: list = [False] * self.columns.nlevels - self.hidden_rows: Sequence[int] = [] # sequence for specific hidden rows/cols + self.hidden_rows: Sequence[int] = [ + ] # sequence for specific hidden rows/cols self.hidden_columns: Sequence[int] = [] self.ctx: DefaultDict[tuple[int, int], CSSList] = defaultdict(list) - self.ctx_index: DefaultDict[tuple[int, int], CSSList] = defaultdict(list) - self.ctx_columns: DefaultDict[tuple[int, int], CSSList] = defaultdict(list) + self.ctx_index: DefaultDict[tuple[int, int], + CSSList] = defaultdict(list) + self.ctx_columns: DefaultDict[tuple[int, int], + CSSList] = defaultdict(list) self.cell_context: DefaultDict[tuple[int, int], str] = defaultdict(str) self._todo: list[tuple[Callable, tuple, dict]] = [] self.tooltips: Tooltips | None = None - precision = ( - get_option("styler.format.precision") if precision is None else precision - ) + precision = (get_option("styler.format.precision") + if precision is None else precision) self._display_funcs: DefaultDict[ # maps (row, col) -> format func - tuple[int, int], Callable[[Any], str] - ] = defaultdict(lambda: partial(_default_formatter, precision=precision)) + tuple[int, int], Callable[[Any], str]] = defaultdict( + lambda: partial(_default_formatter, precision=precision)) self._display_funcs_index: DefaultDict[ # maps (row, level) -> format func - tuple[int, int], Callable[[Any], str] - ] = defaultdict(lambda: partial(_default_formatter, precision=precision)) + tuple[int, int], Callable[[Any], str]] = defaultdict( + lambda: partial(_default_formatter, precision=precision)) self._display_funcs_columns: DefaultDict[ # maps (level, col) -> format func - tuple[int, int], Callable[[Any], str] - ] = defaultdict(lambda: partial(_default_formatter, precision=precision)) + tuple[int, int], Callable[[Any], str]] = defaultdict( + lambda: partial(_default_formatter, precision=precision)) def _render( self, @@ -175,16 +180,16 @@ def _render( "row": f"{self.css['foot']}_{self.css['row']}", "foot": self.css["foot"], } - dx = self.concatenated._render( - sparse_index, sparse_columns, max_rows, max_cols, blank - ) + dx = self.concatenated._render(sparse_index, sparse_columns, + max_rows, max_cols, blank) for (r, c), v in self.concatenated.ctx.items(): self.ctx[(r + len(self.index), c)] = v for (r, c), v in self.concatenated.ctx_index.items(): self.ctx_index[(r + len(self.index), c)] = v - d = self._translate(sparse_index, sparse_columns, max_rows, max_cols, blank, dx) + d = self._translate(sparse_index, sparse_columns, 
max_rows, max_cols, + blank, dx) return d def _render_html( @@ -199,7 +204,8 @@ def _render_html( Renders the ``Styler`` including all applied styles to HTML. Generates a dict with necessary kwargs passed to jinja2 template. """ - d = self._render(sparse_index, sparse_columns, max_rows, max_cols, " ") + d = self._render(sparse_index, sparse_columns, max_rows, max_cols, + " ") d.update(kwargs) return self.template_html.render( **d, @@ -207,9 +213,8 @@ def _render_html( html_style_tpl=self.template_html_style, ) - def _render_latex( - self, sparse_index: bool, sparse_columns: bool, clines: str | None, **kwargs - ) -> str: + def _render_latex(self, sparse_index: bool, sparse_columns: bool, + clines: str | None, **kwargs) -> str: """ Render a Styler in latex format """ @@ -300,8 +305,10 @@ def _translate( } max_elements = get_option("styler.render.max_elements") - max_rows = max_rows if max_rows else get_option("styler.render.max_rows") - max_cols = max_cols if max_cols else get_option("styler.render.max_columns") + max_rows = max_rows if max_rows else get_option( + "styler.render.max_rows") + max_cols = max_cols if max_cols else get_option( + "styler.render.max_columns") max_rows, max_cols = _get_trimming_maximums( len(self.data.index), len(self.data.columns), @@ -310,24 +317,20 @@ def _translate( max_cols, ) - self.cellstyle_map_columns: DefaultDict[ - tuple[CSSPair, ...], list[str] - ] = defaultdict(list) + self.cellstyle_map_columns: DefaultDict[tuple[CSSPair, ...], + list[str]] = defaultdict(list) head = self._translate_header(sparse_cols, max_cols) d.update({"head": head}) # for sparsifying a MultiIndex and for use with latex clines - idx_lengths = _get_level_lengths( - self.index, sparse_index, max_rows, self.hidden_rows - ) + idx_lengths = _get_level_lengths(self.index, sparse_index, max_rows, + self.hidden_rows) d.update({"index_lengths": idx_lengths}) - self.cellstyle_map: DefaultDict[tuple[CSSPair, ...], list[str]] = defaultdict( - list - ) - self.cellstyle_map_index: DefaultDict[ - tuple[CSSPair, ...], list[str] - ] = defaultdict(list) + self.cellstyle_map: DefaultDict[tuple[CSSPair, ...], + list[str]] = defaultdict(list) + self.cellstyle_map_index: DefaultDict[tuple[CSSPair, ...], + list[str]] = defaultdict(list) body: list = self._translate_body(idx_lengths, max_rows, max_cols) d.update({"body": body}) @@ -337,22 +340,24 @@ def _translate( "cellstyle_columns": "cellstyle_map_columns", } # add the cell_ids styles map to the render dictionary in right format for k, attr in ctx_maps.items(): - map = [ - {"props": list(props), "selectors": selectors} - for props, selectors in getattr(self, attr).items() - ] + map = [{ + "props": list(props), + "selectors": selectors + } for props, selectors in getattr(self, attr).items()] d.update({k: map}) if dx is not None: # self.concatenated is not None d["body"].extend(dx["body"]) # type: ignore[union-attr] d["cellstyle"].extend(dx["cellstyle"]) # type: ignore[union-attr] - d["cellstyle_index"].extend(dx["cellstyle"]) # type: ignore[union-attr] + d["cellstyle_index"].extend( + dx["cellstyle"]) # type: ignore[union-attr] table_attr = self.table_attributes if not get_option("styler.html.mathjax"): table_attr = table_attr or "" if 'class="' in table_attr: - table_attr = table_attr.replace('class="', 'class="tex2jax_ignore ') + table_attr = table_attr.replace('class="', + 'class="tex2jax_ignore ') else: table_attr += ' class="tex2jax_ignore"' d.update({"table_attributes": table_attr}) @@ -388,9 +393,8 @@ def _translate_header(self, sparsify_cols: 
bool, max_cols: int): The associated HTML elements needed for template rendering. """ # for sparsifying a MultiIndex - col_lengths = _get_level_lengths( - self.columns, sparsify_cols, max_cols, self.hidden_columns - ) + col_lengths = _get_level_lengths(self.columns, sparsify_cols, max_cols, + self.hidden_columns) clabels = self.data.columns.tolist() if self.data.columns.nlevels == 1: @@ -404,25 +408,20 @@ def _translate_header(self, sparsify_cols: bool, max_cols: int): continue else: header_row = self._generate_col_header_row( - (r, clabels), max_cols, col_lengths - ) + (r, clabels), max_cols, col_lengths) head.append(header_row) # 2) index names - if ( - self.data.index.names - and com.any_not_none(*self.data.index.names) - and not all(self.hide_index_) - and not self.hide_index_names - ): + if (self.data.index.names and com.any_not_none(*self.data.index.names) + and not all(self.hide_index_) and not self.hide_index_names): index_names_row = self._generate_index_names_row( - clabels, max_cols, col_lengths - ) + clabels, max_cols, col_lengths) head.append(index_names_row) return head - def _generate_col_header_row(self, iter: tuple, max_cols: int, col_lengths: dict): + def _generate_col_header_row(self, iter: tuple, max_cols: int, + col_lengths: dict): """ Generate the row containing column headers: @@ -455,14 +454,10 @@ def _generate_col_header_row(self, iter: tuple, max_cols: int, col_lengths: dict column_name = [ _element( "th", - ( - f"{self.css['blank']} {self.css['level']}{r}" - if name is None - else f"{self.css['index_name']} {self.css['level']}{r}" - ), - name - if (name is not None and not self.hide_column_names) - else self.css["blank_value"], + (f"{self.css['blank']} {self.css['level']}{r}" if name is None + else f"{self.css['index_name']} {self.css['level']}{r}"), + name if (name is not None and not self.hide_column_names) else + self.css["blank_value"], not all(self.hide_index_), ) ] @@ -474,48 +469,43 @@ def _generate_col_header_row(self, iter: tuple, max_cols: int, col_lengths: dict if header_element_visible: visible_col_count += col_lengths.get((r, c), 0) if self._check_trim( - visible_col_count, - max_cols, - column_headers, - "th", - f"{self.css['col_heading']} {self.css['level']}{r} " - f"{self.css['col_trim']}", + visible_col_count, + max_cols, + column_headers, + "th", + f"{self.css['col_heading']} {self.css['level']}{r} " + f"{self.css['col_trim']}", ): break header_element = _element( "th", - ( - f"{self.css['col_heading']} {self.css['level']}{r} " - f"{self.css['col']}{c}" - ), + (f"{self.css['col_heading']} {self.css['level']}{r} " + f"{self.css['col']}{c}"), value, header_element_visible, display_value=self._display_funcs_columns[(r, c)](value), - attributes=( - f'colspan="{col_lengths.get((r, c), 0)}"' - if col_lengths.get((r, c), 0) > 1 - else "" - ), + attributes=(f'colspan="{col_lengths.get((r, c), 0)}"' + if col_lengths.get((r, c), 0) > 1 else ""), ) if self.cell_ids: - header_element["id"] = f"{self.css['level']}{r}_{self.css['col']}{c}" - if ( - header_element_visible - and (r, c) in self.ctx_columns - and self.ctx_columns[r, c] - ): - header_element["id"] = f"{self.css['level']}{r}_{self.css['col']}{c}" - self.cellstyle_map_columns[tuple(self.ctx_columns[r, c])].append( - f"{self.css['level']}{r}_{self.css['col']}{c}" - ) + header_element[ + "id"] = f"{self.css['level']}{r}_{self.css['col']}{c}" + if (header_element_visible and (r, c) in self.ctx_columns + and self.ctx_columns[r, c]): + header_element[ + "id"] = 
f"{self.css['level']}{r}_{self.css['col']}{c}" + self.cellstyle_map_columns[tuple(self.ctx_columns[ + r, + c])].append(f"{self.css['level']}{r}_{self.css['col']}{c}") column_headers.append(header_element) return index_blanks + column_name + column_headers - def _generate_index_names_row(self, iter: tuple, max_cols: int, col_lengths: dict): + def _generate_index_names_row(self, iter: tuple, max_cols: int, + col_lengths: dict): """ Generate the row containing index names @@ -543,8 +533,7 @@ def _generate_index_names_row(self, iter: tuple, max_cols: int, col_lengths: dic f"{self.css['index_name']} {self.css['level']}{c}", self.css["blank_value"] if name is None else name, not self.hide_index_[c], - ) - for c, name in enumerate(self.data.index.names) + ) for c, name in enumerate(self.data.index.names) ] column_blanks: list = [] @@ -552,16 +541,17 @@ def _generate_index_names_row(self, iter: tuple, max_cols: int, col_lengths: dic if clabels: last_level = self.columns.nlevels - 1 # use last level since never sparsed for c, value in enumerate(clabels[last_level]): - header_element_visible = _is_visible(c, last_level, col_lengths) + header_element_visible = _is_visible(c, last_level, + col_lengths) if header_element_visible: visible_col_count += 1 if self._check_trim( - visible_col_count, - max_cols, - column_blanks, - "th", - f"{self.css['blank']} {self.css['col']}{c} {self.css['col_trim']}", - self.css["blank_value"], + visible_col_count, + max_cols, + column_blanks, + "th", + f"{self.css['blank']} {self.css['col']}{c} {self.css['col_trim']}", + self.css["blank_value"], ): break @@ -571,8 +561,7 @@ def _generate_index_names_row(self, iter: tuple, max_cols: int, col_lengths: dic f"{self.css['blank']} {self.css['col']}{c}", self.css["blank_value"], c not in self.hidden_columns, - ) - ) + )) return index_names + column_blanks @@ -605,20 +594,20 @@ def _translate_body(self, idx_lengths: dict, max_rows: int, max_cols: int): body: list = [] visible_row_count: int = 0 for r, row_tup in [ - z for z in enumerate(self.data.itertuples()) if z[0] not in self.hidden_rows + z for z in enumerate(self.data.itertuples()) + if z[0] not in self.hidden_rows ]: visible_row_count += 1 if self._check_trim( - visible_row_count, - max_rows, - body, - "row", + visible_row_count, + max_rows, + body, + "row", ): break - body_row = self._generate_body_row( - (r, row_tup, rlabels), max_cols, idx_lengths - ) + body_row = self._generate_body_row((r, row_tup, rlabels), max_cols, + idx_lengths) body.append(body_row) return body @@ -678,15 +667,12 @@ def _generate_trimmed_row(self, max_cols: int) -> list: index_headers = [ _element( "th", - ( - f"{self.css['row_heading']} {self.css['level']}{c} " - f"{self.css['row_trim']}" - ), + (f"{self.css['row_heading']} {self.css['level']}{c} " + f"{self.css['row_trim']}"), "...", not self.hide_index_[c], attributes="", - ) - for c in range(self.data.index.nlevels) + ) for c in range(self.data.index.nlevels) ] data: list = [] @@ -696,11 +682,11 @@ def _generate_trimmed_row(self, max_cols: int) -> list: if data_element_visible: visible_col_count += 1 if self._check_trim( - visible_col_count, - max_cols, - data, - "td", - f"{self.css['data']} {self.css['row_trim']} {self.css['col_trim']}", + visible_col_count, + max_cols, + data, + "td", + f"{self.css['data']} {self.css['row_trim']} {self.css['col_trim']}", ): break @@ -711,8 +697,7 @@ def _generate_trimmed_row(self, max_cols: int) -> list: "...", data_element_visible, attributes="", - ) - ) + )) return index_headers + data @@ -746,56 +731,45 
@@ def _generate_body_row( index_headers = [] for c, value in enumerate(rlabels[r]): - header_element_visible = ( - _is_visible(r, c, idx_lengths) and not self.hide_index_[c] - ) + header_element_visible = (_is_visible(r, c, idx_lengths) + and not self.hide_index_[c]) header_element = _element( "th", - ( - f"{self.css['row_heading']} {self.css['level']}{c} " - f"{self.css['row']}{r}" - ), + (f"{self.css['row_heading']} {self.css['level']}{c} " + f"{self.css['row']}{r}"), value, header_element_visible, display_value=self._display_funcs_index[(r, c)](value), - attributes=( - f'rowspan="{idx_lengths.get((c, r), 0)}"' - if idx_lengths.get((c, r), 0) > 1 - else "" - ), + attributes=(f'rowspan="{idx_lengths.get((c, r), 0)}"' + if idx_lengths.get((c, r), 0) > 1 else ""), ) if self.cell_ids: header_element[ - "id" - ] = f"{self.css['level']}{c}_{self.css['row']}{r}" # id is given - if ( - header_element_visible - and (r, c) in self.ctx_index - and self.ctx_index[r, c] - ): + "id"] = f"{self.css['level']}{c}_{self.css['row']}{r}" # id is given + if (header_element_visible and (r, c) in self.ctx_index + and self.ctx_index[r, c]): # always add id if a style is specified - header_element["id"] = f"{self.css['level']}{c}_{self.css['row']}{r}" + header_element[ + "id"] = f"{self.css['level']}{c}_{self.css['row']}{r}" self.cellstyle_map_index[tuple(self.ctx_index[r, c])].append( - f"{self.css['level']}{c}_{self.css['row']}{r}" - ) + f"{self.css['level']}{c}_{self.css['row']}{r}") index_headers.append(header_element) data: list = [] visible_col_count: int = 0 for c, value in enumerate(row_tup[1:]): - data_element_visible = ( - c not in self.hidden_columns and r not in self.hidden_rows - ) + data_element_visible = (c not in self.hidden_columns + and r not in self.hidden_rows) if data_element_visible: visible_col_count += 1 if self._check_trim( - visible_col_count, - max_cols, - data, - "td", - f"{self.css['data']} {self.css['row']}{r} {self.css['col_trim']}", + visible_col_count, + max_cols, + data, + "td", + f"{self.css['data']} {self.css['row']}{r} {self.css['col_trim']}", ): break @@ -806,10 +780,8 @@ def _generate_body_row( data_element = _element( "td", - ( - f"{self.css['data']} {self.css['row']}{r} " - f"{self.css['col']}{c}{cls}" - ), + (f"{self.css['data']} {self.css['row']}{r} " + f"{self.css['col']}{c}{cls}"), value, data_element_visible, attributes="", @@ -817,13 +789,14 @@ def _generate_body_row( ) if self.cell_ids: - data_element["id"] = f"{self.css['row']}{r}_{self.css['col']}{c}" + data_element[ + "id"] = f"{self.css['row']}{r}_{self.css['col']}{c}" if data_element_visible and (r, c) in self.ctx and self.ctx[r, c]: # always add id if needed due to specified style - data_element["id"] = f"{self.css['row']}{r}_{self.css['col']}{c}" + data_element[ + "id"] = f"{self.css['row']}{r}_{self.css['col']}{c}" self.cellstyle_map[tuple(self.ctx[r, c])].append( - f"{self.css['row']}{r}_{self.css['col']}{c}" - ) + f"{self.css['row']}{r}_{self.css['col']}{c}") data.append(data_element) @@ -841,29 +814,23 @@ def _translate_latex(self, d: dict, clines: str | None) -> None: """ index_levels = self.index.nlevels visible_index_level_n = index_levels - sum(self.hide_index_) - d["head"] = [ - [ - {**col, "cellstyle": self.ctx_columns[r, c - visible_index_level_n]} - for c, col in enumerate(row) - if col["is_visible"] - ] - for r, row in enumerate(d["head"]) - ] + d["head"] = [[{ + **col, "cellstyle": + self.ctx_columns[r, c - visible_index_level_n] + } for c, col in enumerate(row) if col["is_visible"]] + for r, 
row in enumerate(d["head"])] def concatenated_visible_rows(obj, n, row_indices): """ Extract all visible row indices recursively from concatenated stylers. """ - row_indices.extend( - [r + n for r in range(len(obj.index)) if r not in obj.hidden_rows] - ) - return ( - row_indices - if obj.concatenated is None - else concatenated_visible_rows( - obj.concatenated, n + len(obj.index), row_indices - ) - ) + row_indices.extend([ + r + n for r in range(len(obj.index)) + if r not in obj.hidden_rows + ]) + return (row_indices if obj.concatenated is None else + concatenated_visible_rows(obj.concatenated, n + + len(obj.index), row_indices)) body = [] for r, row in zip(concatenated_visible_rows(self, 0, []), d["body"]): @@ -876,20 +843,18 @@ def concatenated_visible_rows(obj, n, row_indices): row_body_headers = [ { **col, - "display_value": col["display_value"] - if col["is_visible"] - else "", - "cellstyle": self.ctx_index[r, c], - } - for c, col in enumerate(row[:index_levels]) + "display_value": + col["display_value"] if col["is_visible"] else "", + "cellstyle": + self.ctx_index[r, c], + } for c, col in enumerate(row[:index_levels]) if (col["type"] == "th" and not self.hide_index_[c]) ] - row_body_cells = [ - {**col, "cellstyle": self.ctx[r, c]} - for c, col in enumerate(row[index_levels:]) - if (col["is_visible"] and col["type"] == "td") - ] + row_body_cells = [{ + **col, "cellstyle": self.ctx[r, c] + } for c, col in enumerate(row[index_levels:]) + if (col["is_visible"] and col["type"] == "td")] body.append(row_body_headers + row_body_cells) d["body"] = body @@ -897,11 +862,11 @@ def concatenated_visible_rows(obj, n, row_indices): # clines are determined from info on index_lengths and hidden_rows and input # to a dict defining which row clines should be added in the template. if clines not in [ - None, - "all;data", - "all;index", - "skip-last;data", - "skip-last;index", + None, + "all;data", + "all;index", + "skip-last;data", + "skip-last;index", ]: raise ValueError( f"`clines` value of {clines} is invalid. Should either be None or one " @@ -912,7 +877,8 @@ def concatenated_visible_rows(obj, n, row_indices): d["clines"] = defaultdict(list) visible_row_indexes: list[int] = [ - r for r in range(len(self.data.index)) if r not in self.hidden_rows + r for r in range(len(self.data.index)) + if r not in self.hidden_rows ] visible_index_levels: list[int] = [ i for i in range(index_levels) if not self.hide_index_[i] @@ -1118,8 +1084,7 @@ def format( .. 
figure:: ../../_static/style/format_excel_css.png """ - if all( - ( + if all(( formatter is None, subset is None, precision is None, @@ -1128,8 +1093,7 @@ def format( na_rep is None, escape is None, hyperlinks is None, - ) - ): + )): self._display_funcs.clear() return self # clear the formatter / revert to default and avoid looping @@ -1306,8 +1270,7 @@ def format_index( display_funcs_, obj = self._display_funcs_columns, self.columns levels_ = refactor_levels(level, obj) - if all( - ( + if all(( formatter is None, level is None, precision is None, @@ -1316,8 +1279,7 @@ def format_index( na_rep is None, escape is None, hyperlinks is None, - ) - ): + )): display_funcs_.clear() return self # clear the formatter / revert to default and avoid looping @@ -1340,7 +1302,8 @@ def format_index( hyperlinks=hyperlinks, ) - for idx in [(i, lvl) if axis == 0 else (lvl, i) for i in range(len(obj))]: + for idx in [(i, lvl) if axis == 0 else (lvl, i) + for i in range(len(obj))]: display_funcs_[idx] = format_func return self @@ -1487,7 +1450,8 @@ def _get_level_lengths( lengths[(i, last_label)] += 1 non_zero_lengths = { - element: length for element, length in lengths.items() if length >= 1 + element: length + for element, length in lengths.items() if length >= 1 } return non_zero_lengths @@ -1507,11 +1471,10 @@ def format_table_styles(styles: CSSStyles) -> CSSStyles: ---> [{'selector': 'td', 'props': 'a:v;'}, {'selector': 'th', 'props': 'a:v;'}] """ - return [ - {"selector": selector, "props": css_dict["props"]} - for css_dict in styles - for selector in css_dict["selector"].split(",") - ] + return [{ + "selector": selector, + "props": css_dict["props"] + } for css_dict in styles for selector in css_dict["selector"].split(",")] def _default_formatter(x: Any, precision: int, thousands: bool = False) -> Any: @@ -1539,9 +1502,8 @@ def _default_formatter(x: Any, precision: int, thousands: bool = False) -> Any: return x -def _wrap_decimal_thousands( - formatter: Callable, decimal: str, thousands: str | None -) -> Callable: +def _wrap_decimal_thousands(formatter: Callable, decimal: str, + thousands: str | None) -> Callable: """ Takes a string formatting function and wraps logic to deal with thousands and decimal parameters, in the case that they are non-standard and that the input @@ -1551,12 +1513,9 @@ def _wrap_decimal_thousands( def wrapper(x): if is_float(x) or is_integer(x) or is_complex(x): if decimal != "." and thousands is not None and thousands != ",": - return ( - formatter(x) - .replace(",", "§_§-") # rare string to avoid "," <-> "." clash. - .replace(".", decimal) - .replace("§_§-", thousands) - ) + return (formatter(x).replace( + ",", "§_§-") # rare string to avoid "," <-> "." clash. + .replace(".", decimal).replace("§_§-", thousands)) elif decimal != "." and (thousands is None or thousands == ","): return formatter(x).replace(".", decimal) elif decimal == "." 
and thousands is not None and thousands != ",": @@ -1588,7 +1547,8 @@ def _render_href(x, format): elif format == "latex": href = r"\href{{{0}}}{{{0}}}" else: - raise ValueError("``hyperlinks`` format can only be 'html' or 'latex'") + raise ValueError( + "``hyperlinks`` format can only be 'html' or 'latex'") pat = r"(https?:\/\/|ftp:\/\/|www.)[\w/\-?=%.]+\.[\w/\-&?=%.]+" return re.sub(pat, lambda m: href.format(m.group(0)), x) return x @@ -1614,14 +1574,14 @@ def _maybe_wrap_formatter( elif callable(formatter): func_0 = formatter elif formatter is None: - precision = ( - get_option("styler.format.precision") if precision is None else precision - ) - func_0 = partial( - _default_formatter, precision=precision, thousands=(thousands is not None) - ) + precision = (get_option("styler.format.precision") + if precision is None else precision) + func_0 = partial(_default_formatter, + precision=precision, + thousands=(thousands is not None)) else: - raise TypeError(f"'formatter' expected str or callable, got {type(formatter)}") + raise TypeError( + f"'formatter' expected str or callable, got {type(formatter)}") # Replace chars if escaping if escape is not None: @@ -1631,7 +1591,9 @@ def _maybe_wrap_formatter( # Replace decimals and thousands if non-standard inputs detected if decimal != "." or (thousands is not None and thousands != ","): - func_2 = _wrap_decimal_thousands(func_1, decimal=decimal, thousands=thousands) + func_2 = _wrap_decimal_thousands(func_1, + decimal=decimal, + thousands=thousands) else: func_2 = func_1 @@ -1687,7 +1649,8 @@ def pred(part) -> bool: else: # error: Item "slice" of "Union[slice, Sequence[Any]]" has no attribute # "__iter__" (not iterable) -> is specifically list_like in conditional - slice_ = [p if pred(p) else [p] for p in slice_] # type: ignore[union-attr] + slice_ = [p if pred(p) else [p] + for p in slice_] # type: ignore[union-attr] return tuple(slice_) @@ -1700,16 +1663,12 @@ def maybe_convert_css_to_tuples(style: CSSProperties) -> CSSList: if isinstance(style, str): s = style.split(";") try: - return [ - (x.split(":")[0].strip(), x.split(":")[1].strip()) - for x in s - if x.strip() != "" - ] + return [(x.split(":")[0].strip(), x.split(":")[1].strip()) + for x in s if x.strip() != ""] except IndexError: raise ValueError( "Styles supplied as string must follow CSS rule formats, " - f"for example 'attr: val;'. '{style}' was given." - ) + f"for example 'attr: val;'. 
'{style}' was given.") return style @@ -1743,7 +1702,8 @@ def refactor_levels( for lev in level ] else: - raise ValueError("`level` must be of type `int`, `str` or list of such") + raise ValueError( + "`level` must be of type `int`, `str` or list of such") return levels_ @@ -1778,17 +1738,17 @@ class Tooltips: """ def __init__( - self, - css_props: CSSProperties = [ - ("visibility", "hidden"), - ("position", "absolute"), - ("z-index", 1), - ("background-color", "black"), - ("color", "white"), - ("transform", "translate(-20px, -20px)"), - ], - css_name: str = "pd-t", - tooltips: DataFrame = DataFrame(), + self, + css_props: CSSProperties = [ + ("visibility", "hidden"), + ("position", "absolute"), + ("z-index", 1), + ("background-color", "black"), + ("color", "white"), + ("transform", "translate(-20px, -20px)"), + ], + css_name: str = "pd-t", + tooltips: DataFrame = DataFrame(), ): self.class_name = css_name self.class_properties = css_props @@ -1806,12 +1766,10 @@ def _class_styles(self): ------- styles : List """ - return [ - { - "selector": f".{self.class_name}", - "props": maybe_convert_css_to_tuples(self.class_properties), - } - ] + return [{ + "selector": f".{self.class_name}", + "props": maybe_convert_css_to_tuples(self.class_properties), + }] def _pseudo_css(self, uuid: str, name: str, row: int, col: int, text: str): """ @@ -1882,20 +1840,17 @@ def _translate(self, styler: StylerRenderer, d: dict): return d name = self.class_name - mask = (self.tt_data.isna()) | (self.tt_data.eq("")) # empty string = no ttip + mask = (self.tt_data.isna()) | (self.tt_data.eq("") + ) # empty string = no ttip self.table_styles = [ - style - for sublist in [ - self._pseudo_css(styler.uuid, name, i, j, str(self.tt_data.iloc[i, j])) + style for sublist in [ + self._pseudo_css(styler.uuid, name, i, j, + str(self.tt_data.iloc[i, j])) for i in range(len(self.tt_data.index)) for j in range(len(self.tt_data.columns)) - if not ( - mask.iloc[i, j] - or i in styler.hidden_rows - or j in styler.hidden_columns - ) - ] - for style in sublist + if not (mask.iloc[i, j] or i in styler.hidden_rows + or j in styler.hidden_columns) + ] for style in sublist ] if self.table_styles: @@ -1904,16 +1859,16 @@ def _translate(self, styler: StylerRenderer, d: dict): for item in row: if item["type"] == "td": item["display_value"] = ( - str(item["display_value"]) - + f'' - ) + str(item["display_value"]) + + f'') d["table_styles"].extend(self._class_styles) d["table_styles"].extend(self.table_styles) return d -def _parse_latex_table_wrapping(table_styles: CSSStyles, caption: str | None) -> bool: +def _parse_latex_table_wrapping(table_styles: CSSStyles, + caption: str | None) -> bool: """ Indicate whether LaTeX {tabular} should be wrapped with a {table} environment. @@ -1923,13 +1878,13 @@ def _parse_latex_table_wrapping(table_styles: CSSStyles, caption: str | None) -> """ IGNORED_WRAPPERS = ["toprule", "midrule", "bottomrule", "column_format"] # ignored selectors are included with {tabular} so do not need wrapping - return ( - table_styles is not None - and any(d["selector"] not in IGNORED_WRAPPERS for d in table_styles) - ) or caption is not None + return (table_styles is not None + and any(d["selector"] not in IGNORED_WRAPPERS + for d in table_styles)) or caption is not None -def _parse_latex_table_styles(table_styles: CSSStyles, selector: str) -> str | None: +def _parse_latex_table_styles(table_styles: CSSStyles, + selector: str) -> str | None: """ Return the first 'props' 'value' from ``tables_styles`` identified by ``selector``. 
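For illustration, a minimal standalone sketch of the split-on-";"-then-":" parsing that the `maybe_convert_css_to_tuples` hunk above re-wraps (behaviour unchanged by the patch); the helper name is hypothetical and this is not the pandas implementation:

def css_to_tuples(style: str):
    # Split a CSS rule string such as "color: red; font-weight: bold;"
    # into (attribute, value) tuples, skipping empty fragments (sketch only).
    try:
        return [
            (rule.split(":")[0].strip(), rule.split(":")[1].strip())
            for rule in style.split(";")
            if rule.strip() != ""
        ]
    except IndexError:
        raise ValueError(
            "Styles supplied as string must follow CSS rule formats, "
            f"for example 'attr: val;'. '{style}' was given."
        )

print(css_to_tuples("color: red; font-weight: bold;"))
# [('color', 'red'), ('font-weight', 'bold')]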
@@ -1946,15 +1901,16 @@ def _parse_latex_table_styles(table_styles: CSSStyles, selector: str) -> str | N The replacement of "§" with ":" is to avoid the CSS problem where ":" has structural significance and cannot be used in LaTeX labels, but is often required by them. """ - for style in table_styles[::-1]: # in reverse for most recently applied style + for style in table_styles[:: + -1]: # in reverse for most recently applied style if style["selector"] == selector: return str(style["props"][0][1]).replace("§", ":") return None -def _parse_latex_cell_styles( - latex_styles: CSSList, display_value: str, convert_css: bool = False -) -> str: +def _parse_latex_cell_styles(latex_styles: CSSList, + display_value: str, + convert_css: bool = False) -> str: r""" Mutate the ``display_value`` string including LaTeX commands from ``latex_styles``. @@ -1982,7 +1938,8 @@ def _parse_latex_cell_styles( """ if convert_css: latex_styles = _parse_latex_css_conversion(latex_styles) - for (command, options) in latex_styles[::-1]: # in reverse for most recent style + for (command, + options) in latex_styles[::-1]: # in reverse for most recent style formatter = { "--wrap": f"{{\\{command}--to_parse {display_value}}}", "--nowrap": f"\\{command}--to_parse {display_value}", @@ -1994,8 +1951,8 @@ def _parse_latex_cell_styles( for arg in ["--nowrap", "--wrap", "--lwrap", "--rwrap", "--dwrap"]: if arg in str(options): display_value = formatter[arg].replace( - "--to_parse", _parse_latex_options_strip(value=options, arg=arg) - ) + "--to_parse", + _parse_latex_options_strip(value=options, arg=arg)) break # only ever one purposeful entry return display_value @@ -2026,14 +1983,14 @@ def _parse_latex_header_span( >>> _parse_latex_header_span(cell, 't', 'c') '\\multicolumn{3}{c}{text}' """ - display_val = _parse_latex_cell_styles( - cell["cellstyle"], cell["display_value"], convert_css - ) + display_val = _parse_latex_cell_styles(cell["cellstyle"], + cell["display_value"], convert_css) if "attributes" in cell: attrs = cell["attributes"] if 'colspan="' in attrs: - colspan = attrs[attrs.find('colspan="') + 9 :] # len('colspan="') = 9 - colspan = int(colspan[: colspan.find('"')]) + colspan = attrs[attrs.find('colspan="') + + 9:] # len('colspan="') = 9 + colspan = int(colspan[:colspan.find('"')]) if "naive-l" == multicol_align: out = f"{{{display_val}}}" if wrap else f"{display_val}" blanks = " & {}" if wrap else " &" @@ -2046,8 +2003,8 @@ def _parse_latex_header_span( elif 'rowspan="' in attrs: if multirow_align == "naive": return display_val - rowspan = attrs[attrs.find('rowspan="') + 9 :] - rowspan = int(rowspan[: rowspan.find('"')]) + rowspan = attrs[attrs.find('rowspan="') + 9:] + rowspan = int(rowspan[:rowspan.find('"')]) return f"\\multirow[{multirow_align}]{{{rowspan}}}{{*}}{{{display_val}}}" if wrap: return f"{{{display_val}}}" @@ -2062,7 +2019,8 @@ def _parse_latex_options_strip(value: str | int | float, arg: str) -> str: For example: 'red /* --wrap */ ' --> 'red' """ - return str(value).replace(arg, "").replace("/*", "").replace("*/", "").strip() + return str(value).replace(arg, "").replace("/*", "").replace("*/", + "").strip() def _parse_latex_css_conversion(styles: CSSList) -> CSSList: @@ -2119,7 +2077,9 @@ def color(value, user_arg, command, comm_arg): CONVERTED_ATTRIBUTES: dict[str, Callable] = { "font-weight": font_weight, - "background-color": partial(color, command="cellcolor", comm_arg="--lwrap"), + "background-color": partial(color, + command="cellcolor", + comm_arg="--lwrap"), "color": partial(color, 
command="color", comm_arg=""), "font-style": font_style, } @@ -2158,19 +2118,18 @@ def _escape_latex(s): str : Escaped string """ - return ( - s.replace("\\", "ab2§=§8yz") # rare string for final conversion: avoid \\ clash - .replace("ab2§=§8yz ", "ab2§=§8yz\\space ") # since \backslash gobbles spaces - .replace("&", "\\&") - .replace("%", "\\%") - .replace("$", "\\$") - .replace("#", "\\#") - .replace("_", "\\_") - .replace("{", "\\{") - .replace("}", "\\}") - .replace("~ ", "~\\space ") # since \textasciitilde gobbles spaces - .replace("~", "\\textasciitilde ") - .replace("^ ", "^\\space ") # since \textasciicircum gobbles spaces - .replace("^", "\\textasciicircum ") - .replace("ab2§=§8yz", "\\textbackslash ") - ) + return (s.replace( + "\\", "ab2§=§8yz") # rare string for final conversion: avoid \\ clash + .replace("ab2§=§8yz ", + "ab2§=§8yz\\space ") # since \backslash gobbles spaces + .replace("&", "\\&").replace( + "%", "\\%").replace("$", "\\$").replace("#", "\\#").replace( + "_", + "\\_").replace("{", "\\{").replace("}", "\\}").replace( + "~ ", + "~\\space ") # since \textasciitilde gobbles spaces + .replace("~", "\\textasciitilde ").replace( + "^ ", "^\\space ") # since \textasciicircum gobbles spaces + .replace("^", + "\\textasciicircum ").replace("ab2§=§8yz", + "\\textbackslash ")) diff --git a/pandas/io/json/_json.py b/pandas/io/json/_json.py index c48fd2534fb71..025e050ebf601 100644 --- a/pandas/io/json/_json.py +++ b/pandas/io/json/_json.py @@ -88,8 +88,7 @@ def to_json( if not index and orient not in ["split", "table"]: raise ValueError( - "'index=False' is only valid when 'orient' is 'split' or 'table'" - ) + "'index=False' is only valid when 'orient' is 'split' or 'table'") if lines and orient != "records": raise ValueError("'lines' keyword only valid when 'orient' is records") @@ -124,9 +123,10 @@ def to_json( if path_or_buf is not None: # apply compression and byte/text conversion - with get_handle( - path_or_buf, "w", compression=compression, storage_options=storage_options - ) as handles: + with get_handle(path_or_buf, + "w", + compression=compression, + storage_options=storage_options) as handles: handles.handle.write(s) else: return s @@ -199,7 +199,8 @@ def obj_to_write(self) -> NDFrame | Mapping[IndexLabel, Any]: def _format_axes(self): if not self.obj.index.is_unique and self.orient == "index": - raise ValueError(f"Series index must be unique for orient='{self.orient}'") + raise ValueError( + f"Series index must be unique for orient='{self.orient}'") class FrameWriter(Writer): @@ -218,14 +219,14 @@ def _format_axes(self): """ Try to format axes if they are datelike. """ - if not self.obj.index.is_unique and self.orient in ("index", "columns"): + if not self.obj.index.is_unique and self.orient in ("index", + "columns"): raise ValueError( - f"DataFrame index must be unique for orient='{self.orient}'." - ) + f"DataFrame index must be unique for orient='{self.orient}'.") if not self.obj.columns.is_unique and self.orient in ( - "index", - "columns", - "records", + "index", + "columns", + "records", ): raise ValueError( f"DataFrame columns must be unique for orient='{self.orient}'." @@ -269,8 +270,7 @@ def __init__( msg = ( "Trying to write with `orient='table'` and " f"`date_format='{date_format}'`. 
Table Schema requires dates " - "to be formatted with `date_format='iso'`" - ) + "to be formatted with `date_format='iso'`") raise ValueError(msg) self.schema = build_table_schema(obj, index=self.index) @@ -278,15 +278,11 @@ def __init__( # NotImplemented on a column MultiIndex if obj.ndim == 2 and isinstance(obj.columns, MultiIndex): raise NotImplementedError( - "orient='table' is not supported for MultiIndex columns" - ) + "orient='table' is not supported for MultiIndex columns") # TODO: Do this timedelta properly in objToJSON.c See GH #15137 - if ( - (obj.ndim == 1) - and (obj.name in set(obj.index.names)) - or len(obj.columns.intersection(obj.index.names)) - ): + if ((obj.ndim == 1) and (obj.name in set(obj.index.names)) + or len(obj.columns.intersection(obj.index.names))): msg = "Overlapping names between the index and columns" raise ValueError(msg) @@ -314,12 +310,13 @@ def obj_to_write(self) -> NDFrame | Mapping[IndexLabel, Any]: @doc( storage_options=_shared_docs["storage_options"], - decompression_options=_shared_docs["decompression_options"] % "path_or_buf", + decompression_options=_shared_docs["decompression_options"] % + "path_or_buf", ) @deprecate_kwarg(old_arg_name="numpy", new_arg_name=None) -@deprecate_nonkeyword_arguments( - version="2.0", allowed_args=["path_or_buf"], stacklevel=3 -) +@deprecate_nonkeyword_arguments(version="2.0", + allowed_args=["path_or_buf"], + stacklevel=3) def read_json( path_or_buf=None, orient=None, @@ -701,12 +698,10 @@ def _get_data_from_filepath(self, filepath_or_buffer): """ # if it is a string but the file does not exist, it might be a JSON string filepath_or_buffer = stringify_path(filepath_or_buffer) - if ( - not isinstance(filepath_or_buffer, str) - or is_url(filepath_or_buffer) - or is_fsspec_url(filepath_or_buffer) - or file_exists(filepath_or_buffer) - ): + if (not isinstance(filepath_or_buffer, str) + or is_url(filepath_or_buffer) + or is_fsspec_url(filepath_or_buffer) + or file_exists(filepath_or_buffer)): self.handles = get_handle( filepath_or_buffer, "r", @@ -850,7 +845,8 @@ def __init__( if date_unit is not None: date_unit = date_unit.lower() if date_unit not in self._STAMP_UNITS: - raise ValueError(f"date_unit must be one of {self._STAMP_UNITS}") + raise ValueError( + f"date_unit must be one of {self._STAMP_UNITS}") self.min_stamp = self._MIN_STAMPS[date_unit] else: self.min_stamp = self._MIN_STAMPS["s"] @@ -870,7 +866,8 @@ def check_keys_split(self, decoded): bad_keys = set(decoded.keys()).difference(set(self._split_keys)) if bad_keys: bad_keys_joined = ", ".join(bad_keys) - raise ValueError(f"JSON data had unexpected key(s): {bad_keys_joined}") + raise ValueError( + f"JSON data had unexpected key(s): {bad_keys_joined}") def parse(self): @@ -911,9 +908,11 @@ def _convert_axes(self): def _try_convert_types(self): raise AbstractMethodError(self) - def _try_convert_data( - self, name, data, use_dtypes: bool = True, convert_dates: bool = True - ): + def _try_convert_data(self, + name, + data, + use_dtypes: bool = True, + convert_dates: bool = True): """ Try to parse a ndarray like into a column by inferring dtype. 
""" @@ -933,9 +932,8 @@ def _try_convert_data( pass else: # dtype to force - dtype = ( - self.dtype.get(name) if isinstance(self.dtype, dict) else self.dtype - ) + dtype = (self.dtype.get(name) + if isinstance(self.dtype, dict) else self.dtype) if dtype is not None: try: return data.astype(dtype), True @@ -1012,18 +1010,19 @@ def _try_convert_to_date(self, data): # ignore numbers that are out of range if issubclass(new_data.dtype.type, np.number): - in_range = ( - isna(new_data._values) - | (new_data > self.min_stamp) - | (new_data._values == iNaT) - ) + in_range = (isna(new_data._values) + | (new_data > self.min_stamp) + | (new_data._values == iNaT)) if not in_range.all(): return data, False - date_units = (self.date_unit,) if self.date_unit else self._STAMP_UNITS + date_units = ( + self.date_unit, ) if self.date_unit else self._STAMP_UNITS for date_unit in date_units: try: - new_data = to_datetime(new_data, errors="raise", unit=date_unit) + new_data = to_datetime(new_data, + errors="raise", + unit=date_unit) except (ValueError, OverflowError, TypeError): continue return new_data, True @@ -1045,7 +1044,8 @@ def _parse_no_numpy(self): self.check_keys_split(decoded) self.obj = create_series_with_explicit_dtype(**decoded) else: - self.obj = create_series_with_explicit_dtype(data, dtype_if_empty=object) + self.obj = create_series_with_explicit_dtype(data, + dtype_if_empty=object) def _parse_numpy(self): load_kwargs = { @@ -1066,17 +1066,17 @@ def _parse_numpy(self): # error: "create_series_with_explicit_dtype" # gets multiple values for keyword argument "dtype_if_empty self.obj = create_series_with_explicit_dtype( - *data, dtype_if_empty=object - ) # type: ignore[misc] + *data, dtype_if_empty=object) # type: ignore[misc] else: - self.obj = create_series_with_explicit_dtype(data, dtype_if_empty=object) + self.obj = create_series_with_explicit_dtype(data, + dtype_if_empty=object) def _try_convert_types(self): if self.obj is None: return - obj, result = self._try_convert_data( - "data", self.obj, convert_dates=self.convert_dates - ) + obj, result = self._try_convert_data("data", + self.obj, + convert_dates=self.convert_dates) if result: self.obj = obj @@ -1102,26 +1102,27 @@ def _parse_numpy(self): args = (args[0].T, args[2], args[1]) self.obj = DataFrame(*args) elif orient == "split": - decoded = loads( - json, dtype=None, numpy=True, precise_float=self.precise_float - ) + decoded = loads(json, + dtype=None, + numpy=True, + precise_float=self.precise_float) decoded = {str(k): v for k, v in decoded.items()} self.check_keys_split(decoded) self.obj = DataFrame(**decoded) elif orient == "values": self.obj = DataFrame( - loads(json, dtype=None, numpy=True, precise_float=self.precise_float) - ) + loads(json, + dtype=None, + numpy=True, + precise_float=self.precise_float)) else: - self.obj = DataFrame( - *loads( - json, - dtype=None, - numpy=True, - labelled=True, - precise_float=self.precise_float, - ) - ) + self.obj = DataFrame(*loads( + json, + dtype=None, + numpy=True, + labelled=True, + precise_float=self.precise_float, + )) def _parse_no_numpy(self): @@ -1129,13 +1130,13 @@ def _parse_no_numpy(self): orient = self.orient if orient == "columns": - self.obj = DataFrame( - loads(json, precise_float=self.precise_float), dtype=None - ) + self.obj = DataFrame(loads(json, precise_float=self.precise_float), + dtype=None) elif orient == "split": decoded = { str(k): v - for k, v in loads(json, precise_float=self.precise_float).items() + for k, v in loads(json, + precise_float=self.precise_float).items() } 
self.check_keys_split(decoded) self.obj = DataFrame(dtype=None, **decoded) @@ -1146,11 +1147,11 @@ def _parse_no_numpy(self): orient="index", ) elif orient == "table": - self.obj = parse_table_schema(json, precise_float=self.precise_float) + self.obj = parse_table_schema(json, + precise_float=self.precise_float) else: - self.obj = DataFrame( - loads(json, precise_float=self.precise_float), dtype=None - ) + self.obj = DataFrame(loads(json, precise_float=self.precise_float), + dtype=None) def _process_converter(self, f, filt=None): """ @@ -1186,8 +1187,7 @@ def _try_convert_types(self): self._try_convert_dates() self._process_converter( - lambda col, c: self._try_convert_data(col, c, convert_dates=False) - ) + lambda col, c: self._try_convert_data(col, c, convert_dates=False)) def _try_convert_dates(self): if self.obj is None: @@ -1207,20 +1207,15 @@ def is_ok(col) -> bool: return False col_lower = col.lower() - if ( - col_lower.endswith("_at") - or col_lower.endswith("_time") - or col_lower == "modified" - or col_lower == "date" - or col_lower == "datetime" - or col_lower.startswith("timestamp") - ): + if (col_lower.endswith("_at") or col_lower.endswith("_time") + or col_lower == "modified" or col_lower == "date" + or col_lower == "datetime" + or col_lower.startswith("timestamp")): return True return False self._process_converter( lambda col, c: self._try_convert_to_date(c), - lambda col, c: ( - (self.keep_default_dates and is_ok(col)) or col in convert_dates - ), + lambda col, c: + ((self.keep_default_dates and is_ok(col)) or col in convert_dates), ) diff --git a/pandas/io/parsers/python_parser.py b/pandas/io/parsers/python_parser.py index b762c5426540d..655a6a0f09709 100644 --- a/pandas/io/parsers/python_parser.py +++ b/pandas/io/parsers/python_parser.py @@ -56,6 +56,7 @@ class PythonParser(ParserBase): + def __init__(self, f: ReadCsvBuffer[str] | list, **kwds): """ Workhorse function for processing nested list into DataFrame @@ -140,9 +141,8 @@ def __init__(self, f: ReadCsvBuffer[str] | list, **kwds): # multiple date column thing turning into a real spaghetti factory if not self._has_complex_date_col: - (index_names, self.orig_names, self.columns) = self._get_index_name( - self.columns - ) + (index_names, self.orig_names, + self.columns) = self._get_index_name(self.columns) self._name_processed = True if self.index_names is None: self.index_names = index_names @@ -150,12 +150,12 @@ def __init__(self, f: ReadCsvBuffer[str] | list, **kwds): if self._col_indices is None: self._col_indices = list(range(len(self.columns))) - self._parse_date_cols = self._validate_parse_dates_presence(self.columns) + self._parse_date_cols = self._validate_parse_dates_presence( + self.columns) no_thousands_columns: set[int] | None = None if self.parse_dates: no_thousands_columns = self._set_noconvert_dtype_columns( - self._col_indices, self.columns - ) + self._col_indices, self.columns) self._no_thousands_columns = no_thousands_columns if len(self.decimal) != 1: @@ -166,10 +166,8 @@ def __init__(self, f: ReadCsvBuffer[str] | list, **kwds): regex = rf"^[\-\+]?[0-9]*({decimal}[0-9]*)?([0-9]?(E|e)\-?[0-9]+)?$" else: thousands = re.escape(self.thousands) - regex = ( - rf"^[\-\+]?([0-9]+{thousands}|[0-9])*({decimal}[0-9]*)?" - rf"([0-9]?(E|e)\-?[0-9]+)?$" - ) + regex = (rf"^[\-\+]?([0-9]+{thousands}|[0-9])*({decimal}[0-9]*)?" 
+ rf"([0-9]?(E|e)\-?[0-9]+)?$") self.num = re.compile(regex) def _make_reader(self, f: IO[str] | ReadCsvBuffer[str]) -> None: @@ -240,10 +238,10 @@ def _read(): self.data = reader # type: ignore[assignment] def read( - self, rows: int | None = None - ) -> tuple[ - Index | None, Sequence[Hashable] | MultiIndex, Mapping[Hashable, ArrayLike] - ]: + self, + rows: int | None = None + ) -> tuple[Index | None, Sequence[Hashable] | MultiIndex, Mapping[ + Hashable, ArrayLike]]: try: content = self._get_lines(rows) except StopIteration: @@ -267,7 +265,8 @@ def read( self.index_names, self.dtype, ) - conv_columns = self._maybe_make_multi_index_columns(columns, self.col_names) + conv_columns = self._maybe_make_multi_index_columns( + columns, self.col_names) return index, conv_columns, col_dict # handle new style for names in index @@ -283,9 +282,8 @@ def read( conv_data = self._convert_data(data) columns, conv_data = self._do_date_conversions(columns, conv_data) - index, result_columns = self._make_index( - conv_data, alldata, columns, indexnamerow - ) + index, result_columns = self._make_index(conv_data, alldata, columns, + indexnamerow) return index, result_columns, conv_data @@ -304,7 +302,8 @@ def _exclude_implicit_index( self._check_data_length(names, alldata) return { - name: alldata[i + offset] for i, name in enumerate(names) if i < len_alldata + name: alldata[i + offset] + for i, name in enumerate(names) if i < len_alldata }, names # legacy @@ -350,8 +349,7 @@ def _convert_data( ) def _infer_columns( - self, - ) -> tuple[list[list[Scalar | None]], int, set[Scalar | None]]: + self, ) -> tuple[list[list[Scalar | None]], int, set[Scalar | None]]: names = self.names num_original_columns = 0 clear_buffer = True @@ -382,8 +380,7 @@ def _infer_columns( if self.line_pos < hr: raise ValueError( f"Passed header={hr} but only {self.line_pos + 1} lines in " - "file" - ) from err + "file") from err # We have an empty file, so check # if columns are provided. 
That will @@ -395,7 +392,8 @@ def _infer_columns( return columns, num_original_columns, unnamed_cols if not self.names: - raise EmptyDataError("No columns to parse from file") from err + raise EmptyDataError( + "No columns to parse from file") from err line = self.names[:] @@ -419,8 +417,7 @@ def _infer_columns( # Ensure that regular columns are used before unnamed ones # to keep given names and mangle unnamed columns col_loop_order = [ - i - for i in range(len(this_columns)) + i for i in range(len(this_columns)) if i not in this_unnamed_cols ] + this_unnamed_cols @@ -438,13 +435,12 @@ def _infer_columns( else: cur_count = counts[col] - if ( - self.dtype is not None - and is_dict_like(self.dtype) - and self.dtype.get(old_col) is not None - and self.dtype.get(col) is None - ): - self.dtype.update({col: self.dtype.get(old_col)}) + if (self.dtype is not None + and is_dict_like(self.dtype) + and self.dtype.get(old_col) is not None + and self.dtype.get(col) is None): + self.dtype.update( + {col: self.dtype.get(old_col)}) this_columns[i] = col counts[col] = cur_count + 1 elif have_mi_columns: @@ -460,13 +456,16 @@ def _infer_columns( unnamed_count = len(this_unnamed_cols) # if wrong number of blanks or no index, not our format - if (lc != unnamed_count and lc - ic > unnamed_count) or ic == 0: + if (lc != unnamed_count + and lc - ic > unnamed_count) or ic == 0: clear_buffer = False this_columns = [None] * lc self.buf = [self.buf[-1]] columns.append(this_columns) - unnamed_cols.update({this_columns[i] for i in this_unnamed_cols}) + unnamed_cols.update( + {this_columns[i] + for i in this_unnamed_cols}) if len(columns) == 1: num_original_columns = len(this_columns) @@ -482,15 +481,16 @@ def _infer_columns( except StopIteration: first_line = None - len_first_data_row = 0 if first_line is None else len(first_line) + len_first_data_row = 0 if first_line is None else len( + first_line) - if len(names) > len(columns[0]) and len(names) > len_first_data_row: - raise ValueError( - "Number of passed names did not match " - "number of header fields in the file" - ) + if len(names) > len( + columns[0]) and len(names) > len_first_data_row: + raise ValueError("Number of passed names did not match " + "number of header fields in the file") if len(columns) > 1: - raise TypeError("Cannot pass names with multi-index columns") + raise TypeError( + "Cannot pass names with multi-index columns") if self.usecols is not None: # Set _use_cols. 
We don't store columns because they are @@ -499,22 +499,21 @@ def _infer_columns( else: num_original_columns = len(names) if self._col_indices is not None and len(names) != len( - self._col_indices - ): + self._col_indices): columns = [[names[i] for i in sorted(self._col_indices)]] else: columns = [names] else: - columns = self._handle_usecols( - columns, columns[0], num_original_columns - ) + columns = self._handle_usecols(columns, columns[0], + num_original_columns) else: try: line = self._buffered_line() except StopIteration as err: if not names: - raise EmptyDataError("No columns to parse from file") from err + raise EmptyDataError( + "No columns to parse from file") from err line = names[:] @@ -528,19 +527,19 @@ def _infer_columns( columns = [[f"{self.prefix}{i}" for i in range(ncols)]] else: columns = [list(range(ncols))] - columns = self._handle_usecols( - columns, columns[0], num_original_columns - ) + columns = self._handle_usecols(columns, columns[0], + num_original_columns) else: if self.usecols is None or len(names) >= num_original_columns: - columns = self._handle_usecols([names], names, num_original_columns) + columns = self._handle_usecols([names], names, + num_original_columns) num_original_columns = len(names) else: - if not callable(self.usecols) and len(names) != len(self.usecols): + if not callable( + self.usecols) and len(names) != len(self.usecols): raise ValueError( "Number of passed names did not match number of " - "header fields in the file" - ) + "header fields in the file") # Ignore output but set used columns. self._handle_usecols([names], names, ncols) columns = [names] @@ -566,8 +565,7 @@ def _handle_usecols( elif any(isinstance(u, str) for u in self.usecols): if len(columns) > 1: raise ValueError( - "If using multiple headers, usecols must be integers." - ) + "If using multiple headers, usecols must be integers.") col_indices = [] for col in self.usecols: @@ -575,7 +573,8 @@ def _handle_usecols( try: col_indices.append(usecols_key.index(col)) except ValueError: - self._validate_usecols_names(self.usecols, usecols_key) + self._validate_usecols_names( + self.usecols, usecols_key) else: col_indices.append(col) else: @@ -591,10 +590,8 @@ def _handle_usecols( ) col_indices = self.usecols - columns = [ - [n for i, n in enumerate(column) if i in col_indices] - for column in columns - ] + columns = [[n for i, n in enumerate(column) if i in col_indices] + for column in columns] self._col_indices = sorted(col_indices) return columns @@ -651,7 +648,7 @@ def _check_for_bom(self, first_row: list[Scalar]) -> list[Scalar]: # Extract any remaining data after the second # quotation mark. 
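For illustration, a simplified sketch of how `_handle_usecols` above resolves string `usecols` against the header into sorted positional indices; the helper name and error text are hypothetical simplifications:

def resolve_usecols(header, usecols):
    # Map string usecols to their positions in the header, keep integer
    # usecols as-is, then filter the header to those positions (sketch only).
    col_indices = []
    for col in usecols:
        if isinstance(col, str):
            if col not in header:
                raise ValueError(f"column {col!r} not found in header")
            col_indices.append(header.index(col))
        else:
            col_indices.append(col)
    col_indices = sorted(col_indices)
    return [header[i] for i in col_indices], col_indices

print(resolve_usecols(["a", "b", "c", "d"], ["d", "b"]))
# (['b', 'd'], [1, 3])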
if len(first_row_bom) > end + 1: - new_row += first_row_bom[end + 1 :] + new_row += first_row_bom[end + 1:] else: @@ -688,9 +685,8 @@ def _next_line(self) -> list[Scalar]: line = self._check_comments([self.data[self.pos]])[0] self.pos += 1 # either uncommented or blank to begin with - if not self.skip_blank_lines and ( - self._is_line_empty(self.data[self.pos - 1]) or line - ): + if not self.skip_blank_lines and (self._is_line_empty( + self.data[self.pos - 1]) or line): break elif self.skip_blank_lines: ret = self._remove_empty_lines([line]) @@ -776,25 +772,24 @@ def _next_iter_line(self, row_num: int) -> list[Scalar] | None: assert isinstance(line, list) return line except csv.Error as e: - if self.on_bad_lines in (self.BadLineHandleMethod.ERROR, self.BadLineHandleMethod.WARN): + if self.on_bad_lines in ( + self.BadLineHandleMethod.ERROR, + self.BadLineHandleMethod.WARN, + ): msg = str(e) if "NULL byte" in msg or "line contains NUL" in msg: - msg = ( - "NULL byte detected. This byte " - "cannot be processed in Python's " - "native csv library at the moment, " - "so please pass in engine='c' instead" - ) + msg = ("NULL byte detected. This byte " + "cannot be processed in Python's " + "native csv library at the moment, " + "so please pass in engine='c' instead") if self.skipfooter > 0: - reason = ( - "Error could possibly be due to " - "parsing errors in the skipped footer rows " - "(the skipfooter keyword is only applied " - "after Python's csv library has parsed " - "all rows)." - ) + reason = ("Error could possibly be due to " + "parsing errors in the skipped footer rows " + "(the skipfooter keyword is only applied " + "after Python's csv library has parsed " + "all rows).") msg += ". " + reason self._alert_malformed(msg, row_num) @@ -807,21 +802,19 @@ def _check_comments(self, lines: list[list[Scalar]]) -> list[list[Scalar]]: for line in lines: rl = [] for x in line: - if ( - not isinstance(x, str) - or self.comment not in x - or x in self.na_values - ): + if (not isinstance(x, str) or self.comment not in x + or x in self.na_values): rl.append(x) else: - x = x[: x.find(self.comment)] + x = x[:x.find(self.comment)] if len(x) > 0: rl.append(x) break ret.append(rl) return ret - def _remove_empty_lines(self, lines: list[list[Scalar]]) -> list[list[Scalar]]: + def _remove_empty_lines(self, + lines: list[list[Scalar]]) -> list[list[Scalar]]: """ Iterate through the lines and remove any that are either empty or contain only one whitespace value @@ -839,35 +832,31 @@ def _remove_empty_lines(self, lines: list[list[Scalar]]) -> list[list[Scalar]]: ret = [] for line in lines: # Remove empty lines and lines with only one whitespace value - if ( - len(line) > 1 - or len(line) == 1 - and (not isinstance(line[0], str) or line[0].strip()) - ): + if (len(line) > 1 or len(line) == 1 and + (not isinstance(line[0], str) or line[0].strip())): ret.append(line) return ret - def _check_thousands(self, lines: list[list[Scalar]]) -> list[list[Scalar]]: + def _check_thousands(self, + lines: list[list[Scalar]]) -> list[list[Scalar]]: if self.thousands is None: return lines - return self._search_replace_num_columns( - lines=lines, search=self.thousands, replace="" - ) + return self._search_replace_num_columns(lines=lines, + search=self.thousands, + replace="") - def _search_replace_num_columns( - self, lines: list[list[Scalar]], search: str, replace: str - ) -> list[list[Scalar]]: + def _search_replace_num_columns(self, lines: list[list[Scalar]], + search: str, + replace: str) -> list[list[Scalar]]: ret = [] for 
line in lines: rl = [] for i, x in enumerate(line): - if ( - not isinstance(x, str) - or search not in x - or (self._no_thousands_columns and i in self._no_thousands_columns) - or not self.num.search(x.strip()) - ): + if (not isinstance(x, str) or search not in x + or (self._no_thousands_columns + and i in self._no_thousands_columns) + or not self.num.search(x.strip())): rl.append(x) else: rl.append(x.replace(search, replace)) @@ -878,9 +867,9 @@ def _check_decimal(self, lines: list[list[Scalar]]) -> list[list[Scalar]]: if self.decimal == parser_defaults["decimal"]: return lines - return self._search_replace_num_columns( - lines=lines, search=self.decimal, replace="." - ) + return self._search_replace_num_columns(lines=lines, + search=self.decimal, + replace=".") def _clear_buffer(self) -> None: self.buf = [] @@ -954,9 +943,8 @@ def _get_index_name( else: # Case 2 - (index_name, _, self.index_col) = self._clean_index_names( - columns, self.index_col - ) + (index_name, _, + self.index_col) = self._clean_index_names(columns, self.index_col) return index_name, orig_names, columns @@ -973,11 +961,9 @@ def _rows_to_cols(self, content: list[list[Scalar]]) -> list[np.ndarray]: # elements are padded with NaN). # error: Non-overlapping identity check (left operand type: "List[int]", # right operand type: "Literal[False]") - if ( - max_len > col_len - and self.index_col is not False # type: ignore[comparison-overlap] - and self.usecols is None - ): + if (max_len > col_len and + self.index_col is not False # type: ignore[comparison-overlap] + and self.usecols is None): footers = self.skipfooter if self.skipfooter else 0 bad_lines = [] @@ -994,7 +980,10 @@ def _rows_to_cols(self, content: list[list[Scalar]]) -> list[np.ndarray]: new_l = self.on_bad_lines(l) if new_l is not None: content.append(new_l) - elif self.on_bad_lines in (self.BadLineHandleMethod.ERROR, self.BadLineHandleMethod.WARN): + elif self.on_bad_lines in ( + self.BadLineHandleMethod.ERROR, + self.BadLineHandleMethod.WARN, + ): row_num = self.pos - (content_len - i + footers) bad_lines.append((row_num, actual_len)) @@ -1004,26 +993,20 @@ def _rows_to_cols(self, content: list[list[Scalar]]) -> list[np.ndarray]: content.append(l) for row_num, actual_len in bad_lines: - msg = ( - f"Expected {col_len} fields in line {row_num + 1}, saw " - f"{actual_len}" - ) - if ( - self.delimiter - and len(self.delimiter) > 1 - and self.quoting != csv.QUOTE_NONE - ): + msg = (f"Expected {col_len} fields in line {row_num + 1}, saw " + f"{actual_len}") + if (self.delimiter and len(self.delimiter) > 1 + and self.quoting != csv.QUOTE_NONE): # see gh-13374 - reason = ( - "Error could possibly be due to quotes being " - "ignored when a multi-char delimiter is used." - ) + reason = ("Error could possibly be due to quotes being " + "ignored when a multi-char delimiter is used.") msg += ". 
" + reason self._alert_malformed(msg, row_num + 1) # see gh-13320 - zipped_content = list(lib.to_object_array(content, min_width=col_len).T) + zipped_content = list( + lib.to_object_array(content, min_width=col_len).T) if self.usecols: assert self._col_indices is not None @@ -1031,12 +1014,9 @@ def _rows_to_cols(self, content: list[list[Scalar]]) -> list[np.ndarray]: if self._implicit_index: zipped_content = [ - a - for i, a in enumerate(zipped_content) - if ( - i < len(self.index_col) - or i - len(self.index_col) in col_indices - ) + a for i, a in enumerate(zipped_content) + if (i < len(self.index_col) or i - + len(self.index_col) in col_indices) ] else: zipped_content = [ @@ -1063,10 +1043,10 @@ def _get_lines(self, rows: int | None = None) -> list[list[Scalar]]: if self.pos > len(self.data): raise StopIteration if rows is None: - new_rows = self.data[self.pos :] + new_rows = self.data[self.pos:] new_pos = len(self.data) else: - new_rows = self.data[self.pos : self.pos + rows] + new_rows = self.data[self.pos:self.pos + rows] new_pos = self.pos + rows new_rows = self._remove_skipped_rows(new_rows) @@ -1082,8 +1062,7 @@ def _get_lines(self, rows: int | None = None) -> list[list[Scalar]]: if self.skiprows is not None and self.pos is not None: # Only read additional rows if pos is in skiprows rows_to_skip = len( - set(self.skiprows) - set(range(self.pos)) - ) + set(self.skiprows) - set(range(self.pos))) for _ in range(rows + rows_to_skip): # assert for mypy, data is Iterator[str] or None, would @@ -1098,7 +1077,8 @@ def _get_lines(self, rows: int | None = None) -> list[list[Scalar]]: rows = 0 while True: - new_row = self._next_iter_line(row_num=self.pos + rows + 1) + new_row = self._next_iter_line(row_num=self.pos + + rows + 1) rows += 1 if new_row is not None: @@ -1118,7 +1098,7 @@ def _get_lines(self, rows: int | None = None) -> list[list[Scalar]]: lines = new_rows if self.skipfooter: - lines = lines[: -self.skipfooter] + lines = lines[:-self.skipfooter] lines = self._check_comments(lines) if self.skip_blank_lines: @@ -1126,10 +1106,12 @@ def _get_lines(self, rows: int | None = None) -> list[list[Scalar]]: lines = self._check_thousands(lines) return self._check_decimal(lines) - def _remove_skipped_rows(self, new_rows: list[list[Scalar]]) -> list[list[Scalar]]: + def _remove_skipped_rows( + self, new_rows: list[list[Scalar]]) -> list[list[Scalar]]: if self.skiprows: return [ - row for i, row in enumerate(new_rows) if not self.skipfunc(i + self.pos) + row for i, row in enumerate(new_rows) + if not self.skipfunc(i + self.pos) ] return new_rows @@ -1153,31 +1135,25 @@ def __init__( self.delimiter = "\r\n" + delimiter if delimiter else "\n\r\t " self.comment = comment if colspecs == "infer": - self.colspecs = self.detect_colspecs( - infer_nrows=infer_nrows, skiprows=skiprows - ) + self.colspecs = self.detect_colspecs(infer_nrows=infer_nrows, + skiprows=skiprows) else: self.colspecs = colspecs if not isinstance(self.colspecs, (tuple, list)): - raise TypeError( - "column specifications must be a list or tuple, " - f"input was a {type(colspecs).__name__}" - ) + raise TypeError("column specifications must be a list or tuple, " + f"input was a {type(colspecs).__name__}") for colspec in self.colspecs: - if not ( - isinstance(colspec, (tuple, list)) - and len(colspec) == 2 - and isinstance(colspec[0], (int, np.integer, type(None))) - and isinstance(colspec[1], (int, np.integer, type(None))) - ): - raise TypeError( - "Each column specification must be " - "2 element tuple or list of integers" - ) - - 
def get_rows(self, infer_nrows: int, skiprows: set[int] | None = None) -> list[str]: + if not (isinstance(colspec, (tuple, list)) and len(colspec) == 2 + and isinstance(colspec[0], (int, np.integer, type(None))) + and isinstance(colspec[1], (int, np.integer, type(None)))): + raise TypeError("Each column specification must be " + "2 element tuple or list of integers") + + def get_rows(self, + infer_nrows: int, + skiprows: set[int] | None = None) -> list[str]: """ Read rows from self.f, skipping as specified. @@ -1216,8 +1192,9 @@ def get_rows(self, infer_nrows: int, skiprows: set[int] | None = None) -> list[s return detect_rows def detect_colspecs( - self, infer_nrows: int = 100, skiprows: set[int] | None = None - ) -> list[tuple[int, int]]: + self, + infer_nrows: int = 100, + skiprows: set[int] | None = None) -> list[tuple[int, int]]: # Regex escape the delimiters delimiters = "".join([rf"\{x}" for x in self.delimiter]) pattern = re.compile(f"([^{delimiters}]+)") @@ -1230,7 +1207,7 @@ def detect_colspecs( rows = [row.partition(self.comment)[0] for row in rows] for row in rows: for m in pattern.finditer(row): - mask[m.start() : m.end()] = 1 + mask[m.start():m.end()] = 1 shifted = np.roll(mask, 1) shifted[0] = 0 edges = np.where((mask ^ shifted) == 1)[0] @@ -1249,7 +1226,10 @@ def __next__(self) -> list[str]: else: line = next(self.f) # type: ignore[arg-type] # Note: 'colspecs' is a sequence of half-open intervals. - return [line[fromm:to].strip(self.delimiter) for (fromm, to) in self.colspecs] + return [ + line[fromm:to].strip(self.delimiter) + for (fromm, to) in self.colspecs + ] class FixedWidthFieldParser(PythonParser): @@ -1274,7 +1254,8 @@ def _make_reader(self, f: IO[str] | ReadCsvBuffer[str]) -> None: self.infer_nrows, ) - def _remove_empty_lines(self, lines: list[list[Scalar]]) -> list[list[Scalar]]: + def _remove_empty_lines(self, + lines: list[list[Scalar]]) -> list[list[Scalar]]: """ Returns the list of lines without the empty ones. With fixed-width fields, empty lines become arrays of empty strings. @@ -1282,8 +1263,7 @@ def _remove_empty_lines(self, lines: list[list[Scalar]]) -> list[list[Scalar]]: See PythonParser._remove_empty_lines. """ return [ - line - for line in lines + line for line in lines if any(not isinstance(e, str) or e.strip() for e in line) ] diff --git a/pandas/io/parsers/readers.py b/pandas/io/parsers/readers.py index e496ff36dd240..00e768047a02b 100644 --- a/pandas/io/parsers/readers.py +++ b/pandas/io/parsers/readers.py @@ -70,8 +70,7 @@ PythonParser, ) -_doc_read_csv_and_table = ( - r""" +_doc_read_csv_and_table = (r""" {summary} Also supports optionally iterating or breaking of the file @@ -200,9 +199,8 @@ na_values : scalar, str, list-like, or dict, optional Additional strings to recognize as NA/NaN. If dict passed, specific per-column NA values. By default the following values are interpreted as - NaN: '""" - + fill("', '".join(sorted(STR_NA_VALUES)), 70, subsequent_indent=" ") - + """'. + NaN: '""" + fill( + "', '".join(sorted(STR_NA_VALUES)), 70, subsequent_indent=" ") + """'. keep_default_na : bool, default True Whether or not to include the default NaN values when parsing the data. 
Depending on whether `na_values` is passed in, the behavior is as follows: @@ -427,9 +425,7 @@ Examples -------- >>> pd.{func_name}('data.csv') # doctest: +SKIP -""" -) - +""") _c_parser_defaults = { "delim_whitespace": False, @@ -475,14 +471,16 @@ class _DeprecationConfig(NamedTuple): _deprecated_defaults: dict[str, _DeprecationConfig] = { - "error_bad_lines": _DeprecationConfig(None, "Use on_bad_lines in the future."), - "warn_bad_lines": _DeprecationConfig(None, "Use on_bad_lines in the future."), - "squeeze": _DeprecationConfig( - None, 'Append .squeeze("columns") to the call to squeeze.' - ), - "prefix": _DeprecationConfig( - None, "Use a list comprehension on the column names in the future." - ), + "error_bad_lines": + _DeprecationConfig(None, "Use on_bad_lines in the future."), + "warn_bad_lines": + _DeprecationConfig(None, "Use on_bad_lines in the future."), + "squeeze": + _DeprecationConfig(None, + 'Append .squeeze("columns") to the call to squeeze.'), + "prefix": + _DeprecationConfig( + None, "Use a list comprehension on the column names in the future."), } @@ -533,15 +531,13 @@ def _validate_names(names): if names is not None: if len(names) != len(set(names)): raise ValueError("Duplicate names are not allowed.") - if not ( - is_list_like(names, allow_sets=False) or isinstance(names, abc.KeysView) - ): + if not (is_list_like(names, allow_sets=False) + or isinstance(names, abc.KeysView)): raise ValueError("Names should be an ordered collection.") -def _read( - filepath_or_buffer: FilePath | ReadCsvBuffer[bytes] | ReadCsvBuffer[str], kwds -) -> DataFrame | TextFileReader: +def _read(filepath_or_buffer: FilePath | ReadCsvBuffer[bytes] + | ReadCsvBuffer[str], kwds) -> DataFrame | TextFileReader: """Generic reader of line files.""" # if we pass a date_parser and parse_dates=False, we should not parse the # dates GH#44366 @@ -822,9 +818,9 @@ def read_csv( ... -@deprecate_nonkeyword_arguments( - version=None, allowed_args=["filepath_or_buffer"], stacklevel=3 -) +@deprecate_nonkeyword_arguments(version=None, + allowed_args=["filepath_or_buffer"], + stacklevel=3) @Appender( _doc_read_csv_and_table.format( func_name="read_csv", @@ -832,8 +828,7 @@ def read_csv( _default_sep="','", storage_options=_shared_docs["storage_options"], decompression_options=_shared_docs["decompression_options"], - ) -) + )) def read_csv( filepath_or_buffer: FilePath | ReadCsvBuffer[bytes] | ReadCsvBuffer[str], sep: str | None | lib.NoDefault = lib.no_default, @@ -1161,9 +1156,9 @@ def read_table( ... 
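For illustration, a standalone sketch of the duplicate/ordered-collection check performed by `_validate_names` above; the explicit set check here is a simplification of the `is_list_like(..., allow_sets=False)` test used in pandas:

def validate_names(names):
    # Reject duplicate names and unordered collections such as sets
    # (simplified sketch of the _validate_names logic above).
    if names is None:
        return
    if len(names) != len(set(names)):
        raise ValueError("Duplicate names are not allowed.")
    if isinstance(names, (set, frozenset)):
        raise ValueError("Names should be an ordered collection.")

validate_names(["a", "b", "c"])   # passes
# validate_names(["a", "a"])      # raises: duplicate names
# validate_names({"a", "b"})      # raises: sets are unordered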
-@deprecate_nonkeyword_arguments( - version=None, allowed_args=["filepath_or_buffer"], stacklevel=3 -) +@deprecate_nonkeyword_arguments(version=None, + allowed_args=["filepath_or_buffer"], + stacklevel=3) @Appender( _doc_read_csv_and_table.format( func_name="read_table", @@ -1171,8 +1166,7 @@ def read_table( _default_sep=r"'\\t' (tab-stop)", storage_options=_shared_docs["storage_options"], decompression_options=_shared_docs["decompression_options"], - ) -) + )) def read_table( filepath_or_buffer: FilePath | ReadCsvBuffer[bytes] | ReadCsvBuffer[str], sep: str | None | lib.NoDefault = lib.no_default, @@ -1260,9 +1254,9 @@ def read_table( return _read(filepath_or_buffer, kwds) -@deprecate_nonkeyword_arguments( - version=None, allowed_args=["filepath_or_buffer"], stacklevel=2 -) +@deprecate_nonkeyword_arguments(version=None, + allowed_args=["filepath_or_buffer"], + stacklevel=2) def read_fwf( filepath_or_buffer: FilePath | ReadCsvBuffer[bytes] | ReadCsvBuffer[str], colspecs: Sequence[tuple[int, int]] | str | None = "infer", @@ -1321,7 +1315,8 @@ def read_fwf( if colspecs is None and widths is None: raise ValueError("Must specify either colspecs or widths") elif colspecs not in (None, "infer") and widths is not None: - raise ValueError("You must specify only one of 'widths' and 'colspecs'") + raise ValueError( + "You must specify only one of 'widths' and 'colspecs'") # Compute 'colspecs' from 'widths', if specified. if widths is not None: @@ -1349,7 +1344,8 @@ def read_fwf( else: len_index = len(index_col) if len(names) + len_index != len(colspecs): - raise ValueError("Length of colspecs must match length of names") + raise ValueError( + "Length of colspecs must match length of names") kwds["colspecs"] = colspecs kwds["infer_nrows"] = infer_nrows @@ -1428,29 +1424,23 @@ def _get_options_with_defaults(self, engine: CSVEngine) -> dict[str, Any]: value = kwds.get(argname, default) # see gh-12935 - if ( - engine == "pyarrow" - and argname in _pyarrow_unsupported - and value != default - and value != getattr(value, "value", default) - ): - if ( - argname == "on_bad_lines" - and kwds.get("error_bad_lines") is not None - ): + if (engine == "pyarrow" and argname in _pyarrow_unsupported + and value != default + and value != getattr(value, "value", default)): + if (argname == "on_bad_lines" + and kwds.get("error_bad_lines") is not None): argname = "error_bad_lines" - elif ( - argname == "on_bad_lines" and kwds.get("warn_bad_lines") is not None - ): + elif (argname == "on_bad_lines" + and kwds.get("warn_bad_lines") is not None): argname = "warn_bad_lines" raise ValueError( f"The {repr(argname)} option is not supported with the " - f"'pyarrow' engine" - ) + f"'pyarrow' engine") elif argname == "mangle_dupe_cols" and value is False: # GH12935 - raise ValueError("Setting mangle_dupe_cols=False is not supported yet") + raise ValueError( + "Setting mangle_dupe_cols=False is not supported yet") else: options[argname] = value @@ -1461,22 +1451,17 @@ def _get_options_with_defaults(self, engine: CSVEngine) -> dict[str, Any]: if engine != "c" and value != default: if "python" in engine and argname not in _python_unsupported: pass - elif ( - value - == _deprecated_defaults.get( - argname, _DeprecationConfig(default, None) - ).default_value - ): + elif (value == _deprecated_defaults.get( + argname, _DeprecationConfig(default, + None)).default_value): pass else: raise ValueError( f"The {repr(argname)} option is not supported with the " - f"{repr(engine)} engine" - ) + f"{repr(engine)} engine") else: value = 
_deprecated_defaults.get( - argname, _DeprecationConfig(default, None) - ).default_value + argname, _DeprecationConfig(default, None)).default_value options[argname] = value if engine == "python-fwf": @@ -1493,12 +1478,10 @@ def _check_file_or_buffer(self, f, engine: CSVEngine) -> None: # when iterating through such an object, meaning it # needs to have that attribute raise ValueError( - "The 'python' engine cannot iterate through this file buffer." - ) + "The 'python' engine cannot iterate through this file buffer.") - def _clean_options( - self, options: dict[str, Any], engine: CSVEngine - ) -> tuple[dict[str, Any], CSVEngine]: + def _clean_options(self, options: dict[str, Any], + engine: CSVEngine) -> tuple[dict[str, Any], CSVEngine]: result = options.copy() fallback_reason = None @@ -1514,10 +1497,8 @@ def _clean_options( if sep is None and not delim_whitespace: if engine in ("c", "pyarrow"): - fallback_reason = ( - f"the '{engine}' engine does not support " - "sep=None with delim_whitespace=False" - ) + fallback_reason = (f"the '{engine}' engine does not support " + "sep=None with delim_whitespace=False") engine = "python" elif sep is not None and len(sep) > 1: if engine == "c" and sep == r"\s+": @@ -1528,8 +1509,7 @@ def _clean_options( fallback_reason = ( f"the '{engine}' engine does not support " "regex separators (separators > 1 char and " - r"different from '\s+' are interpreted as regex)" - ) + r"different from '\s+' are interpreted as regex)") engine = "python" elif delim_whitespace: if "python" in engine: @@ -1546,17 +1526,13 @@ def _clean_options( fallback_reason = ( f"the separator encoded in {encoding} " f"is > 1 char long, and the '{engine}' engine " - "does not support such separators" - ) + "does not support such separators") engine = "python" quotechar = options["quotechar"] if quotechar is not None and isinstance(quotechar, (str, bytes)): - if ( - len(quotechar) == 1 - and ord(quotechar) > 127 - and engine not in ("python", "python-fwf") - ): + if (len(quotechar) == 1 and ord(quotechar) > 127 + and engine not in ("python", "python-fwf")): fallback_reason = ( "ord(quotechar) > 127, meaning the " "quotechar is larger than one byte, " @@ -1583,11 +1559,9 @@ def _clean_options( if fallback_reason: warnings.warn( - ( - "Falling back to the 'python' engine because " - f"{fallback_reason}; you can avoid this warning by specifying " - "engine='python'." - ), + ("Falling back to the 'python' engine because " + f"{fallback_reason}; you can avoid this warning by specifying " + "engine='python'."), ParserWarning, stacklevel=find_stack_level(), ) @@ -1604,11 +1578,11 @@ def _clean_options( parser_default = _c_parser_defaults.get(arg, parser_defaults[arg]) depr_default = _deprecated_defaults[arg] if result.get(arg, depr_default) != depr_default.default_value: - msg = ( - f"The {arg} argument has been deprecated and will be " - f"removed in a future version. {depr_default.msg}\n\n" - ) - warnings.warn(msg, FutureWarning, stacklevel=find_stack_level()) + msg = (f"The {arg} argument has been deprecated and will be " + f"removed in a future version. 
{depr_default.msg}\n\n") + warnings.warn(msg, + FutureWarning, + stacklevel=find_stack_level()) else: result[arg] = parser_default @@ -1624,10 +1598,8 @@ def _clean_options( # type conversion-related if converters is not None: if not isinstance(converters, dict): - raise TypeError( - "Type converters must be a dict or subclass, " - f"input was a {type(converters).__name__}" - ) + raise TypeError("Type converters must be a dict or subclass, " + f"input was a {type(converters).__name__}") else: converters = {} @@ -1642,8 +1614,7 @@ def _clean_options( # pyarrow expects skiprows to be passed as an integer raise ValueError( "skiprows argument must be an integer when using " - "engine='pyarrow'" - ) + "engine='pyarrow'") else: if is_integer(skiprows): skiprows = list(range(skiprows)) @@ -1661,7 +1632,8 @@ def _clean_options( # Default for squeeze is none since we need to check # if user sets it. We then set to False to preserve # previous behavior. - result["squeeze"] = False if options["squeeze"] is None else options["squeeze"] + result["squeeze"] = False if options["squeeze"] is None else options[ + "squeeze"] return result, engine @@ -1737,8 +1709,7 @@ def read(self, nrows: int | None = None) -> DataFrame: columns, col_dict, ) = self._engine.read( # type: ignore[attr-defined] - nrows - ) + nrows) except Exception: self.close() raise @@ -1992,20 +1963,17 @@ def _refine_defaults_read( # the comparison to dialect values by checking if default values # for BOTH "delimiter" and "sep" were provided. if dialect is not None: - kwds["sep_override"] = delimiter is None and ( - sep is lib.no_default or sep == delim_default - ) + kwds["sep_override"] = delimiter is None and (sep is lib.no_default + or sep == delim_default) if delimiter and (sep is not lib.no_default): - raise ValueError("Specified a sep and a delimiter; you can only specify one.") + raise ValueError( + "Specified a sep and a delimiter; you can only specify one.") - if ( - names is not None - and names is not lib.no_default - and prefix is not None - and prefix is not lib.no_default - ): - raise ValueError("Specified named and prefix; you can only specify one.") + if (names is not None and names is not lib.no_default + and prefix is not None and prefix is not lib.no_default): + raise ValueError( + "Specified named and prefix; you can only specify one.") kwds["names"] = None if names is lib.no_default else names kwds["prefix"] = None if prefix is lib.no_default else prefix @@ -2015,17 +1983,14 @@ def _refine_defaults_read( delimiter = sep if delim_whitespace and (delimiter is not lib.no_default): - raise ValueError( - "Specified a delimiter with both sep and " - "delim_whitespace=True; you can only specify one." - ) + raise ValueError("Specified a delimiter with both sep and " + "delim_whitespace=True; you can only specify one.") if delimiter == "\n": raise ValueError( r"Specified \n as separator or delimiter. This forces the python engine " "which does not accept a line terminator. Hence it is not allowed to use " - "the line terminator as separator.", - ) + "the line terminator as separator.", ) if delimiter is lib.no_default: # assign default separator value @@ -2048,8 +2013,7 @@ def _refine_defaults_read( if error_bad_lines is not None or warn_bad_lines is not None: raise ValueError( "Both on_bad_lines and error_bad_lines/warn_bad_lines are set. " - "Please only set on_bad_lines." 
- ) + "Please only set on_bad_lines.") if on_bad_lines == "error": kwds["on_bad_lines"] = ParserBase.BadLineHandleMethod.ERROR elif on_bad_lines == "warn": @@ -2063,7 +2027,8 @@ def _refine_defaults_read( ) kwds["on_bad_lines"] = on_bad_lines else: - raise ValueError(f"Argument {on_bad_lines} is invalid for on_bad_lines") + raise ValueError( + f"Argument {on_bad_lines} is invalid for on_bad_lines") else: if error_bad_lines is not None: # Must check is_bool, because other stuff(e.g. non-empty lists) eval to true @@ -2077,9 +2042,11 @@ def _refine_defaults_read( # None doesn't work because backwards-compatibility reasons validate_bool_kwarg(warn_bad_lines, "warn_bad_lines") if warn_bad_lines: - kwds["on_bad_lines"] = ParserBase.BadLineHandleMethod.WARN + kwds[ + "on_bad_lines"] = ParserBase.BadLineHandleMethod.WARN else: - kwds["on_bad_lines"] = ParserBase.BadLineHandleMethod.SKIP + kwds[ + "on_bad_lines"] = ParserBase.BadLineHandleMethod.SKIP else: # Backwards compat, when only error_bad_lines = false, we warn kwds["on_bad_lines"] = ParserBase.BadLineHandleMethod.WARN @@ -2168,11 +2135,9 @@ def _merge_with_dialect_properties( # Don't warn if the default parameter was passed in, # even if it conflicts with the dialect (gh-23761). if provided not in (parser_default, dialect_val): - msg = ( - f"Conflicting values for '{param}': '{provided}' was " - f"provided, but the dialect specifies '{dialect_val}'. " - "Using the dialect-specified value." - ) + msg = (f"Conflicting values for '{param}': '{provided}' was " + f"provided, but the dialect specifies '{dialect_val}'. " + "Using the dialect-specified value.") # Annoying corner case for not warning about # conflicts between dialect and delimiter parameter. @@ -2181,9 +2146,9 @@ def _merge_with_dialect_properties( conflict_msgs.append(msg) if conflict_msgs: - warnings.warn( - "\n\n".join(conflict_msgs), ParserWarning, stacklevel=find_stack_level() - ) + warnings.warn("\n\n".join(conflict_msgs), + ParserWarning, + stacklevel=find_stack_level()) kwds[param] = dialect_val return kwds diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index 24a69b4e68d14..4164130c43029 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -107,7 +107,6 @@ from pandas.core.internals import Block - # versioning attribute _version = "0.15.2" @@ -157,8 +156,8 @@ def _ensure_term(where, scope_level: int): level = scope_level + 1 if isinstance(where, (list, tuple)): where = [ - Term(term, scope_level=level + 1) if maybe_expression(term) else term - for term in where + Term(term, scope_level=level + + 1) if maybe_expression(term) else term for term in where if term is not None ] elif maybe_expression(where): @@ -226,7 +225,10 @@ class DuplicateWarning(Warning): """ with config.config_prefix("io.hdf"): - config.register_option("dropna_table", False, dropna_doc, validator=config.is_bool) + config.register_option("dropna_table", + False, + dropna_doc, + validator=config.is_bool) config.register_option( "default_format", None, @@ -252,8 +254,7 @@ def _tables(): # depending on the HDF5 version with suppress(AttributeError): _table_file_open_policy_is_strict = ( - tables.file._FILE_OPEN_POLICY == "strict" - ) + tables.file._FILE_OPEN_POLICY == "strict") return _table_mod @@ -309,9 +310,10 @@ def to_hdf( path_or_buf = stringify_path(path_or_buf) if isinstance(path_or_buf, str): - with HDFStore( - path_or_buf, mode=mode, complevel=complevel, complib=complib - ) as store: + with HDFStore(path_or_buf, + mode=mode, + complevel=complevel, + complib=complib) as store: 
f(store) else: f(path_or_buf) @@ -399,8 +401,7 @@ def read_hdf( if mode not in ["r", "r+", "a"]: raise ValueError( f"mode {mode} is not allowed while performing a read. " - f"Allowed modes are r, r+ and a." - ) + f"Allowed modes are r, r+ and a.") # grab the scope if where is not None: where = _ensure_term(where, scope_level=1) @@ -415,8 +416,7 @@ def read_hdf( path_or_buf = stringify_path(path_or_buf) if not isinstance(path_or_buf, str): raise NotImplementedError( - "Support for generic buffers has not been implemented." - ) + "Support for generic buffers has not been implemented.") try: exists = os.path.exists(path_or_buf) @@ -438,8 +438,7 @@ def read_hdf( if len(groups) == 0: raise ValueError( "Dataset(s) incompatible with Pandas data types, " - "not table, or no datasets found in HDF5 file." - ) + "not table, or no datasets found in HDF5 file.") candidate_only_group = groups[0] # For the HDF file to have only one dataset, all other groups @@ -448,10 +447,8 @@ def read_hdf( # before their children.) for group_to_check in groups[1:]: if not _is_metadata_of(group_to_check, candidate_only_group): - raise ValueError( - "key must be provided when HDF5 " - "file contains multiple datasets." - ) + raise ValueError("key must be provided when HDF5 " + "file contains multiple datasets.") key = candidate_only_group._v_pathname return store.select( key, @@ -621,8 +618,7 @@ def __getattr__(self, name: str): except (KeyError, ClosedFileError): pass raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{name}'" - ) + f"'{type(self).__name__}' object has no attribute '{name}'") def __contains__(self, key: str) -> bool: """ @@ -677,7 +673,8 @@ def keys(self, include: str = "pandas") -> list[str]: elif include == "native": assert self._handle is not None # mypy return [ - n._v_pathname for n in self._handle.walk_nodes("/", classname="Table") + n._v_pathname + for n in self._handle.walk_nodes("/", classname="Table") ] raise ValueError( f"`include` should be either 'pandas' or 'native' but is '{include}'" @@ -727,8 +724,7 @@ def open(self, mode: str = "a", **kwargs): if self.is_open: raise PossibleDataLossError( f"Re-opening the file [{self._path}] with mode [{self._mode}] " - "will delete the current file!" - ) + "will delete the current file!") self._mode = mode @@ -737,15 +733,13 @@ def open(self, mode: str = "a", **kwargs): self.close() if self._complevel and self._complevel > 0: - self._filters = _tables().Filters( - self._complevel, self._complib, fletcher32=self._fletcher32 - ) + self._filters = _tables().Filters(self._complevel, + self._complib, + fletcher32=self._fletcher32) if _table_file_open_policy_is_strict and self.is_open: - msg = ( - "Cannot open HDF5 file, which is already opened, " - "even in read-only mode." 
- ) + msg = ("Cannot open HDF5 file, which is already opened, " + "even in read-only mode.") raise ValueError(msg) self._handle = tables.open_file(self._path, self._mode, **kwargs) @@ -867,7 +861,10 @@ def select( # function to call on iteration def func(_start, _stop, _where): - return s.read(start=_start, stop=_stop, where=_where, columns=columns) + return s.read(start=_start, + stop=_stop, + where=_where, + columns=columns) # create the iterator it = TableIterator( @@ -1036,13 +1033,13 @@ def select_as_multiple( if not t.is_table: raise TypeError( f"object [{t.pathname}] is not a table, and cannot be used in all " - "select as multiple" - ) + "select as multiple") if nrows is None: nrows = t.nrows elif t.nrows != nrows: - raise ValueError("all tables must have exactly the same nrows!") + raise ValueError( + "all tables must have exactly the same nrows!") # The isinstance checks here are redundant with the check above, # but necessary for mypy; see GH#29757 @@ -1061,7 +1058,8 @@ def func(_start, _stop, _where): ] # concat and return - return concat(objs, axis=axis, verify_integrity=False)._consolidate() + return concat(objs, axis=axis, + verify_integrity=False)._consolidate() # create the iterator it = TableIterator( @@ -1200,8 +1198,7 @@ def remove(self, key: str, where=None, start=None, stop=None): else: if not s.is_table: raise ValueError( - "can only remove with where on objects written as tables" - ) + "can only remove with where on objects written as tables") return s.delete(where=where, start=start, stop=stop) def append( @@ -1323,14 +1320,12 @@ def append_to_multiple( if axes is not None: raise TypeError( "axes is currently not accepted as a parameter to append_to_multiple; " - "you can create the tables independently instead" - ) + "you can create the tables independently instead") if not isinstance(d, dict): raise ValueError( "append_to_multiple must have a dictionary specified as the " - "way to split the value" - ) + "way to split the value") if selector not in d: raise ValueError( @@ -1379,12 +1374,15 @@ def append_to_multiple( # compute the val val = value.reindex(v, axis=axis) - filtered = ( - {key: value for (key, value) in min_itemsize.items() if key in v} - if min_itemsize is not None - else None - ) - self.append(k, val, data_columns=dc, min_itemsize=filtered, **kwargs) + filtered = ({ + key: value + for (key, value) in min_itemsize.items() if key in v + } if min_itemsize is not None else None) + self.append(k, + val, + data_columns=dc, + min_itemsize=filtered, + **kwargs) def create_table_index( self, @@ -1423,7 +1421,8 @@ def create_table_index( return if not isinstance(s, Table): - raise TypeError("cannot create table index on a Fixed format store") + raise TypeError( + "cannot create table index on a Fixed format store") s.create_index(columns=columns, optlevel=optlevel, kind=kind) def groups(self): @@ -1442,16 +1441,11 @@ def groups(self): assert self._handle is not None # for mypy assert _table_mod is not None # for mypy return [ - g - for g in self._handle.walk_groups() - if ( - not isinstance(g, _table_mod.link.Link) - and ( - getattr(g._v_attrs, "pandas_type", None) - or getattr(g, "table", None) - or (isinstance(g, _table_mod.table.Table) and g._v_name != "table") - ) - ) + g for g in self._handle.walk_groups() + if (not isinstance(g, _table_mod.link.Link) and ( + getattr(g._v_attrs, "pandas_type", None) or getattr( + g, "table", None) or (isinstance(g, _table_mod.table.Table) + and g._v_name != "table"))) ] def walk(self, where="/"): @@ -1556,9 +1550,11 @@ 
def copy( ------- open file handle of the new store """ - new_store = HDFStore( - file, mode=mode, complib=complib, complevel=complevel, fletcher32=fletcher32 - ) + new_store = HDFStore(file, + mode=mode, + complib=complib, + complevel=complevel, + fletcher32=fletcher32) if keys is None: keys = list(self.keys()) if not isinstance(keys, (tuple, list)): @@ -1611,7 +1607,8 @@ def info(self) -> str: s = self.get_storer(k) if s is not None: keys.append(pprint_thing(s.pathname or k)) - values.append(pprint_thing(s or "invalid_HDFStore node")) + values.append( + pprint_thing(s or "invalid_HDFStore node")) except AssertionError: # surface any assertion errors for e.g. debugging raise @@ -1641,7 +1638,8 @@ def _validate_format(self, format: str) -> str: try: format = _FORMAT_MAP[format.lower()] except KeyError as err: - raise TypeError(f"invalid HDFStore format specified [{format}]") from err + raise TypeError( + f"invalid HDFStore format specified [{format}]") from err return format @@ -1663,8 +1661,7 @@ def error(t): # return instead of raising so mypy can tell where we are raising return TypeError( f"cannot properly create the storer for: [{t}] [group->" - f"{group},value->{type(value)},format->{format}" - ) + f"{group},value->{type(value)},format->{format}") pt = _ensure_decoded(getattr(group._v_attrs, "pandas_type", None)) tt = _ensure_decoded(getattr(group._v_attrs, "table_type", None)) @@ -1675,15 +1672,13 @@ def error(t): _tables() assert _table_mod is not None # for mypy if getattr(group, "table", None) or isinstance( - group, _table_mod.table.Table - ): + group, _table_mod.table.Table): pt = "frame_table" tt = "generic_table" else: raise TypeError( "cannot create a storer if the object is not existing " - "nor a value are passed" - ) + "nor a value are passed") else: if isinstance(value, Series): pt = "series" @@ -1765,11 +1760,16 @@ def _write_to_group( group = self._identify_group(key, append) - s = self._create_storer(group, format, value, encoding=encoding, errors=errors) + s = self._create_storer(group, + format, + value, + encoding=encoding, + errors=errors) if append: # raise if we are trying to append to a Fixed format, # or a table that exists (and we are putting) - if not s.is_table or (s.is_table and format == "fixed" and s.is_exists): + if not s.is_table or (s.is_table and format == "fixed" + and s.is_exists): raise ValueError("Can only append to Tables") if not s.is_exists: s.set_object_info() @@ -1777,7 +1777,8 @@ def _write_to_group( s.set_object_info() if not s.is_table and complib: - raise ValueError("Compression not supported on Fixed format stores") + raise ValueError( + "Compression not supported on Fixed format stores") # write the object s.write( @@ -1934,7 +1935,8 @@ def get_result(self, coordinates: bool = False): # return the actual iterator if self.chunksize is not None: if not isinstance(self.s, Table): - raise TypeError("can only use an iterator or chunksize on a table") + raise TypeError( + "can only use an iterator or chunksize on a table") self.coordinates = self.s.read_coordinates(where=self.where) @@ -1944,9 +1946,9 @@ def get_result(self, coordinates: bool = False): if coordinates: if not isinstance(self.s, Table): raise TypeError("can only read_coordinates on a table") - where = self.s.read_coordinates( - where=self.where, start=self.start, stop=self.stop - ) + where = self.s.read_coordinates(where=self.where, + start=self.start, + stop=self.stop) else: where = self.where @@ -2038,21 +2040,18 @@ def set_pos(self, pos: int): def __repr__(self) -> str: 
temp = tuple( - map(pprint_thing, (self.name, self.cname, self.axis, self.pos, self.kind)) - ) - return ",".join( - [ - f"{key}->{value}" - for key, value in zip(["name", "cname", "axis", "pos", "kind"], temp) - ] - ) + map(pprint_thing, + (self.name, self.cname, self.axis, self.pos, self.kind))) + return ",".join([ + f"{key}->{value}" for key, value in zip( + ["name", "cname", "axis", "pos", "kind"], temp) + ]) def __eq__(self, other: Any) -> bool: """compare 2 col items""" return all( getattr(self, a, None) == getattr(other, a, None) - for a in ["name", "cname", "axis", "pos"] - ) + for a in ["name", "cname", "axis", "pos"]) def __ne__(self, other) -> bool: return not self.__eq__(other) @@ -2085,7 +2084,8 @@ def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str): kwargs["freq"] = _ensure_decoded(self.freq) factory: type[Index] | type[DatetimeIndex] = Index - if is_datetime64_dtype(values.dtype) or is_datetime64tz_dtype(values.dtype): + if is_datetime64_dtype(values.dtype) or is_datetime64tz_dtype( + values.dtype): factory = DatetimeIndex elif values.dtype == "i8" and "freq" in kwargs: # PeriodIndex data is stored as i8 @@ -2093,8 +2093,7 @@ def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str): # "Callable[[Any, KwArg(Any)], PeriodIndex]", variable has type # "Union[Type[Index], Type[DatetimeIndex]]") factory = lambda x, **kwds: PeriodIndex( # type: ignore[assignment] - ordinal=x, **kwds - ) + ordinal=x, **kwds) # making an Index instance could throw a number of different errors try: @@ -2144,7 +2143,8 @@ def maybe_set_size(self, min_itemsize=None): min_itemsize = min_itemsize.get(self.name) if min_itemsize is not None and self.typ.itemsize < min_itemsize: - self.typ = _tables().StringCol(itemsize=min_itemsize, pos=self.pos) + self.typ = _tables().StringCol(itemsize=min_itemsize, + pos=self.pos) def validate_names(self): pass @@ -2170,8 +2170,7 @@ def validate_col(self, itemsize=None): f"Trying to store a string with len [{itemsize}] in " f"[{self.cname}] column but\nthis column has a limit of " f"[{c.itemsize}]!\nConsider using min_itemsize to " - "preset the sizes on these columns" - ) + "preset the sizes on these columns") return c.itemsize return None @@ -2200,9 +2199,9 @@ def update_info(self, info): # frequency/name just warn if key in ["freq", "index_name"]: ws = attribute_conflict_doc % (key, existing_value, value) - warnings.warn( - ws, AttributeConflictWarning, stacklevel=find_stack_level() - ) + warnings.warn(ws, + AttributeConflictWarning, + stacklevel=find_stack_level()) # reset idx[key] = None @@ -2212,8 +2211,7 @@ def update_info(self, info): raise ValueError( f"invalid info for [{self.name}] for [{key}], " f"existing_value [{existing_value}] conflicts with " - f"new value [{value}]" - ) + f"new value [{value}]") else: if value is not None or existing_value is not None: idx[key] = value @@ -2233,15 +2231,10 @@ def validate_metadata(self, handler: AppendableTable): if self.meta == "category": new_metadata = self.metadata cur_metadata = handler.read_metadata(self.cname) - if ( - new_metadata is not None - and cur_metadata is not None - and not array_equivalent(new_metadata, cur_metadata) - ): - raise ValueError( - "cannot append a categorical with " - "different categories to the existing" - ) + if (new_metadata is not None and cur_metadata is not None + and not array_equivalent(new_metadata, cur_metadata)): + raise ValueError("cannot append a categorical with " + "different categories to the existing") def write_metadata(self, 
handler: AppendableTable): """set the meta data""" @@ -2337,23 +2330,18 @@ def meta_attr(self) -> str: def __repr__(self) -> str: temp = tuple( - map( - pprint_thing, (self.name, self.cname, self.dtype, self.kind, self.shape) - ) - ) - return ",".join( - [ - f"{key}->{value}" - for key, value in zip(["name", "cname", "dtype", "kind", "shape"], temp) - ] - ) + map(pprint_thing, + (self.name, self.cname, self.dtype, self.kind, self.shape))) + return ",".join([ + f"{key}->{value}" for key, value in zip( + ["name", "cname", "dtype", "kind", "shape"], temp) + ]) def __eq__(self, other: Any) -> bool: """compare 2 col items""" return all( getattr(self, a, None) == getattr(other, a, None) - for a in ["name", "cname", "dtype", "pos"] - ) + for a in ["name", "cname", "dtype", "pos"]) def set_data(self, data: ArrayLike): assert data is not None @@ -2445,8 +2433,10 @@ def validate_attr(self, append): """validate that we have the same order as the existing & same dtype""" if append: existing_fields = getattr(self.attrs, self.kind_attr, None) - if existing_fields is not None and existing_fields != list(self.values): - raise ValueError("appended items do not match existing items in table!") + if existing_fields is not None and existing_fields != list( + self.values): + raise ValueError( + "appended items do not match existing items in table!") existing_dtype = getattr(self.attrs, self.dtype_attr, None) if existing_dtype is not None and existing_dtype != self.dtype: @@ -2509,12 +2499,10 @@ def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str): elif dtype == "date": try: converted = np.asarray( - [date.fromordinal(v) for v in converted], dtype=object - ) + [date.fromordinal(v) for v in converted], dtype=object) except ValueError: converted = np.asarray( - [date.fromtimestamp(v) for v in converted], dtype=object - ) + [date.fromtimestamp(v) for v in converted], dtype=object) elif meta == "category": # we have a categorical @@ -2536,9 +2524,9 @@ def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str): categories = categories[~mask] codes[codes != -1] -= mask.astype(int).cumsum()._values - converted = Categorical.from_codes( - codes, categories=categories, ordered=ordered - ) + converted = Categorical.from_codes(codes, + categories=categories, + ordered=ordered) else: @@ -2549,9 +2537,10 @@ def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str): # convert nans / decode if _ensure_decoded(kind) == "string": - converted = _unconvert_string_array( - converted, nan_rep=nan_rep, encoding=encoding, errors=errors - ) + converted = _unconvert_string_array(converted, + nan_rep=nan_rep, + encoding=encoding, + errors=errors) return self.values, converted @@ -2636,23 +2625,26 @@ def __init__( @property def is_old_version(self) -> bool: - return self.version[0] <= 0 and self.version[1] <= 10 and self.version[2] < 1 + return self.version[0] <= 0 and self.version[1] <= 10 and self.version[ + 2] < 1 @property def version(self) -> tuple[int, int, int]: """compute and set our version""" - version = _ensure_decoded(getattr(self.group._v_attrs, "pandas_version", None)) + version = _ensure_decoded( + getattr(self.group._v_attrs, "pandas_version", None)) try: version = tuple(int(x) for x in version.split(".")) if len(version) == 2: - version = version + (0,) + version = version + (0, ) except AttributeError: version = (0, 0, 0) return version @property def pandas_type(self): - return _ensure_decoded(getattr(self.group._v_attrs, "pandas_type", None)) + return 
_ensure_decoded( + getattr(self.group._v_attrs, "pandas_type", None)) def __repr__(self) -> str: """return a pretty representation of myself""" @@ -2752,15 +2744,16 @@ def read( stop: int | None = None, ): raise NotImplementedError( - "cannot read on an abstract storer: subclasses should implement" - ) + "cannot read on an abstract storer: subclasses should implement") def write(self, **kwargs): raise NotImplementedError( - "cannot write on an abstract storer: subclasses should implement" - ) + "cannot write on an abstract storer: subclasses should implement") - def delete(self, where=None, start: int | None = None, stop: int | None = None): + def delete(self, + where=None, + start: int | None = None, + stop: int | None = None): """ support fully deleting the node in its entirety (only) - where specification must be None @@ -2791,8 +2784,7 @@ def _alias_to_class(self, alias): def _get_index_factory(self, attrs): index_class = self._alias_to_class( - _ensure_decoded(getattr(attrs, "index_class", "")) - ) + _ensure_decoded(getattr(attrs, "index_class", ""))) factory: Callable @@ -2869,7 +2861,10 @@ def get_attrs(self): def write(self, obj, **kwargs): self.set_attrs() - def read_array(self, key: str, start: int | None = None, stop: int | None = None): + def read_array(self, + key: str, + start: int | None = None, + stop: int | None = None): """read an array for the specified node (off of group""" import tables @@ -2903,9 +2898,10 @@ def read_array(self, key: str, start: int | None = None, stop: int | None = None else: return ret - def read_index( - self, key: str, start: int | None = None, stop: int | None = None - ) -> Index: + def read_index(self, + key: str, + start: int | None = None, + stop: int | None = None) -> Index: variety = _ensure_decoded(getattr(self.attrs, f"{key}_variety")) if variety == "multi": @@ -2923,7 +2919,8 @@ def write_index(self, key: str, index: Index): self.write_multi_index(key, index) else: setattr(self.attrs, f"{key}_variety", "regular") - converted = _convert_index("index", index, self.encoding, self.errors) + converted = _convert_index("index", index, self.encoding, + self.errors) self.write_array(key, converted.values) @@ -2943,16 +2940,17 @@ def write_index(self, key: str, index: Index): def write_multi_index(self, key: str, index: MultiIndex): setattr(self.attrs, f"{key}_nlevels", index.nlevels) - for i, (lev, level_codes, name) in enumerate( - zip(index.levels, index.codes, index.names) - ): + for i, (lev, level_codes, + name) in enumerate(zip(index.levels, index.codes, + index.names)): # write the level if is_extension_array_dtype(lev): raise NotImplementedError( "Saving a MultiIndex with an extension dtype is not supported." 
) level_key = f"{key}_level{i}" - conv_level = _convert_index(level_key, lev, self.encoding, self.errors) + conv_level = _convert_index(level_key, lev, self.encoding, + self.errors) self.write_array(level_key, conv_level.values) node = getattr(self.group, level_key) node._v_attrs.kind = conv_level.kind @@ -2965,9 +2963,10 @@ def write_multi_index(self, key: str, index: MultiIndex): label_key = f"{key}_label{i}" self.write_array(label_key, level_codes) - def read_multi_index( - self, key: str, start: int | None = None, stop: int | None = None - ) -> MultiIndex: + def read_multi_index(self, + key: str, + start: int | None = None, + stop: int | None = None) -> MultiIndex: nlevels = getattr(self.attrs, f"{key}_nlevels") levels = [] @@ -2984,18 +2983,21 @@ def read_multi_index( level_codes = self.read_array(label_key, start=start, stop=stop) codes.append(level_codes) - return MultiIndex( - levels=levels, codes=codes, names=names, verify_integrity=True - ) + return MultiIndex(levels=levels, + codes=codes, + names=names, + verify_integrity=True) - def read_index_node( - self, node: Node, start: int | None = None, stop: int | None = None - ) -> Index: + def read_index_node(self, + node: Node, + start: int | None = None, + stop: int | None = None) -> Index: data = node[start:stop] # If the index was an empty array write_array_empty() will # have written a sentinel. Here we replace it with the original. if "shape" in node._v_attrs and np.prod(node._v_attrs.shape) == 0: - data = np.empty(node._v_attrs.shape, dtype=node._v_attrs.value_type) + data = np.empty(node._v_attrs.shape, + dtype=node._v_attrs.value_type) kind = _ensure_decoded(node._v_attrs.kind) name = None @@ -3008,17 +3010,19 @@ def read_index_node( if kind == "date": index = factory( - _unconvert_index( - data, kind, encoding=self.encoding, errors=self.errors - ), + _unconvert_index(data, + kind, + encoding=self.encoding, + errors=self.errors), dtype=object, **kwargs, ) else: index = factory( - _unconvert_index( - data, kind, encoding=self.encoding, errors=self.errors - ), + _unconvert_index(data, + kind, + encoding=self.encoding, + errors=self.errors), **kwargs, ) @@ -3029,15 +3033,16 @@ def read_index_node( def write_array_empty(self, key: str, value: ArrayLike): """write a 0-len array""" # ugly hack for length 0 axes - arr = np.empty((1,) * value.ndim) + arr = np.empty((1, ) * value.ndim) self._handle.create_array(self.group, key, arr) node = getattr(self.group, key) node._v_attrs.value_type = str(value.dtype) node._v_attrs.shape = value.shape - def write_array( - self, key: str, obj: DataFrame | Series, items: Index | None = None - ) -> None: + def write_array(self, + key: str, + obj: DataFrame | Series, + items: Index | None = None) -> None: # TODO: we only have a few tests that get here, the only EA # that gets passed is DatetimeArray, and we never have # both self._filters and EA @@ -3054,8 +3059,7 @@ def write_array( if is_categorical_dtype(value.dtype): raise NotImplementedError( "Cannot store a category dtype in a HDF5 dataset that uses format=" - '"fixed". Use format="table".' - ) + '"fixed". Use format="table".') if not empty_array: if hasattr(value, "T"): # ExtensionArrays (1d) may not have transpose. 
@@ -3074,9 +3078,11 @@ def write_array( # create an empty chunked array and fill it from value if not empty_array: - ca = self._handle.create_carray( - self.group, key, atom, value.shape, filters=self._filters - ) + ca = self._handle.create_carray(self.group, + key, + atom, + value.shape, + filters=self._filters) ca[:] = value else: @@ -3092,9 +3098,12 @@ def write_array( pass else: ws = performance_doc % (inferred_type, key, items) - warnings.warn(ws, PerformanceWarning, stacklevel=find_stack_level()) + warnings.warn(ws, + PerformanceWarning, + stacklevel=find_stack_level()) - vlarr = self._handle.create_vlarray(self.group, key, _tables().ObjectAtom()) + vlarr = self._handle.create_vlarray(self.group, key, + _tables().ObjectAtom()) vlarr.append(value) elif is_datetime64_dtype(value.dtype): @@ -3107,7 +3116,9 @@ def write_array( # error: Item "ExtensionArray" of "Union[Any, ExtensionArray]" has no # attribute "asi8" self._handle.create_array( - self.group, key, value.asi8 # type: ignore[union-attr] + self.group, + key, + value.asi8 # type: ignore[union-attr] ) node = getattr(self.group, key) @@ -3135,7 +3146,7 @@ class SeriesFixed(GenericFixed): @property def shape(self): try: - return (len(self.group.values),) + return (len(self.group.values), ) except (TypeError, AttributeError): return None @@ -3180,7 +3191,7 @@ def shape(self) -> Shape | None: node = self.group.block0_values shape = getattr(node, "shape", None) if shape is not None: - shape = list(shape[0 : (ndim - 1)]) + shape = list(shape[0:(ndim - 1)]) else: shape = [] @@ -3214,7 +3225,9 @@ def read( for i in range(self.nblocks): blk_items = self.read_index(f"block{i}_items") - values = self.read_array(f"block{i}_values", start=_start, stop=_stop) + values = self.read_array(f"block{i}_values", + start=_start, + stop=_stop) columns = items[items.get_indexer(blk_items)] df = DataFrame(values.T, columns=columns, index=axes[1]) @@ -3241,7 +3254,8 @@ def write(self, obj, **kwargs): self.attrs.ndim = data.ndim for i, ax in enumerate(data.axes): if i == 0 and (not ax.is_unique): - raise ValueError("Columns index has to be unique for fixed format") + raise ValueError( + "Columns index has to be unique for fixed format") self.write_index(f"axis{i}", ax) # Supporting mixed-type DataFrame objects...nontrivial @@ -3332,11 +3346,9 @@ def __repr__(self) -> str: ver = f"[{jver}]" jindex_axes = ",".join([a.name for a in self.index_axes]) - return ( - f"{self.pandas_type:12.12}{ver} " - f"(typ->{self.table_type_short},nrows->{self.nrows}," - f"ncols->{self.ncols},indexers->[{jindex_axes}]{dc})" - ) + return (f"{self.pandas_type:12.12}{ver} " + f"(typ->{self.table_type_short},nrows->{self.nrows}," + f"ncols->{self.ncols},indexers->[{jindex_axes}]{dc})") def __getitem__(self, c: str): """return the axis for c""" @@ -3351,10 +3363,8 @@ def validate(self, other): return if other.table_type != self.table_type: - raise TypeError( - "incompatible table_type with existing " - f"[{other.table_type} - {self.table_type}]" - ) + raise TypeError("incompatible table_type with existing " + f"[{other.table_type} - {self.table_type}]") for c in ["index_axes", "non_index_axes", "values_axes"]: sv = getattr(self, c, None) @@ -3370,14 +3380,12 @@ def validate(self, other): if sax != oax: raise ValueError( f"invalid combination of [{c}] on appending data " - f"[{sax}] vs current table [{oax}]" - ) + f"[{sax}] vs current table [{oax}]") # should never get here raise Exception( f"invalid combination of [{c}] on appending data [{sv}] vs " - f"current table [{ov}]" - ) + 
f"current table [{ov}]") @property def is_multi_index(self) -> bool: @@ -3385,8 +3393,7 @@ def is_multi_index(self) -> bool: return isinstance(self.levels, list) def validate_multiindex( - self, obj: DataFrame | Series - ) -> tuple[DataFrame, list[Hashable]]: + self, obj: DataFrame | Series) -> tuple[DataFrame, list[Hashable]]: """ validate that we can store the multi-index; reset and return the new object @@ -3448,8 +3455,7 @@ def data_orientation(self): itertools.chain( [int(a[0]) for a in self.non_index_axes], [int(a.axis) for a in self.index_axes], - ) - ) + )) def queryables(self) -> dict[str, Any]: """return a dict of the kinds allowable columns for this object""" @@ -3459,9 +3465,8 @@ def queryables(self) -> dict[str, Any]: # compute the values_axes queryables d1 = [(a.cname, a) for a in self.index_axes] d2 = [(axis_names[axis], None) for axis, values in self.non_index_axes] - d3 = [ - (v.cname, v) for v in self.values_axes if v.name in set(self.data_columns) - ] + d3 = [(v.cname, v) for v in self.values_axes + if v.name in set(self.data_columns)] # error: Unsupported operand types for + ("List[Tuple[str, IndexCol]]" and # "List[Tuple[str, None]]") @@ -3528,13 +3533,17 @@ def get_attrs(self): self.errors = _ensure_decoded(getattr(self.attrs, "errors", "strict")) self.levels: list[Hashable] = getattr(self.attrs, "levels", None) or [] self.index_axes = [a for a in self.indexables if a.is_an_indexable] - self.values_axes = [a for a in self.indexables if not a.is_an_indexable] + self.values_axes = [ + a for a in self.indexables if not a.is_an_indexable + ] def validate_version(self, where=None): """are we trying to operate on an old version?""" if where is not None: - if self.version[0] <= 0 and self.version[1] <= 10 and self.version[2] < 1: - ws = incompatibility_doc % ".".join([str(x) for x in self.version]) + if self.version[0] <= 0 and self.version[1] <= 10 and self.version[ + 2] < 1: + ws = incompatibility_doc % ".".join( + [str(x) for x in self.version]) warnings.warn(ws, IncompatibilityWarning) def validate_min_itemsize(self, min_itemsize): @@ -3556,8 +3565,7 @@ def validate_min_itemsize(self, min_itemsize): if k not in q: raise ValueError( f"min_itemsize has the key [{k}] which is not an axis or " - "data_column" - ) + "data_column") @cache_readonly def indexables(self): @@ -3631,11 +3639,15 @@ def f(i, c): # Note: the definition of `values_cols` ensures that each # `c` below is a str. - _indexables.extend([f(i, c) for i, c in enumerate(self.attrs.values_cols)]) + _indexables.extend( + [f(i, c) for i, c in enumerate(self.attrs.values_cols)]) return _indexables - def create_index(self, columns=None, optlevel=None, kind: str | None = None): + def create_index(self, + columns=None, + optlevel=None, + kind: str | None = None): """ Create a pytables index on the specified columns. @@ -3708,8 +3720,7 @@ def create_index(self, columns=None, optlevel=None, kind: str | None = None): "cannot be indexed when using table format. Either use " "fixed format, set index=False, or do not include " "the columns containing complex values to " - "data_columns when initializing the table." 
- ) + "data_columns when initializing the table.") v.create_index(**kw) elif c in self.non_index_axes[0][1]: # GH 28156 @@ -3720,8 +3731,10 @@ def create_index(self, columns=None, optlevel=None, kind: str | None = None): ) def _read_axes( - self, where, start: int | None = None, stop: int | None = None - ) -> list[tuple[ArrayLike, ArrayLike]]: + self, + where, + start: int | None = None, + stop: int | None = None) -> list[tuple[ArrayLike, ArrayLike]]: """ Create the axes sniffed from the table. @@ -3758,7 +3771,8 @@ def get_object(cls, obj, transposed: bool): """return the data for this obj""" return obj - def validate_data_columns(self, data_columns, min_itemsize, non_index_axes): + def validate_data_columns(self, data_columns, min_itemsize, + non_index_axes): """ take the input data_columns and min_itemize and create a data columns spec @@ -3769,10 +3783,8 @@ def validate_data_columns(self, data_columns, min_itemsize, non_index_axes): axis, axis_labels = non_index_axes[0] info = self.info.get(axis, {}) if info.get("type") == "MultiIndex" and data_columns: - raise ValueError( - f"cannot use a multi-index on axis [{axis}] with " - f"data_columns {data_columns}" - ) + raise ValueError(f"cannot use a multi-index on axis [{axis}] with " + f"data_columns {data_columns}") # evaluate the passed data_columns, True == use all columns # take only valid axis labels @@ -3785,13 +3797,10 @@ def validate_data_columns(self, data_columns, min_itemsize, non_index_axes): if isinstance(min_itemsize, dict): existing_data_columns = set(data_columns) data_columns = list(data_columns) # ensure we do not modify - data_columns.extend( - [ - k - for k in min_itemsize.keys() - if k != "values" and k not in existing_data_columns - ] - ) + data_columns.extend([ + k for k in min_itemsize.keys() + if k != "values" and k not in existing_data_columns + ]) # return valid columns in the order of our axis return [c for c in data_columns if c in axis_labels] @@ -3832,8 +3841,7 @@ def _create_axes( group = self.group._v_name raise TypeError( f"cannot properly create the storer for: [group->{group}," - f"value->{type(obj)}]" - ) + f"value->{type(obj)}]") # set the default axes if needed if axes is None: @@ -3858,8 +3866,7 @@ def _create_axes( # currently support on ndim-1 axes if len(axes) != self.ndim - 1: raise ValueError( - "currently only support ndim-1 indexers in an AppendableTable" - ) + "currently only support ndim-1 indexers in an AppendableTable") # create according to the new data new_non_index_axes: list = [] @@ -3877,12 +3884,12 @@ def _create_axes( if table_exists: indexer = len(new_non_index_axes) # i.e. 0 exist_axis = self.non_index_axes[indexer][1] - if not array_equivalent(np.array(append_axis), np.array(exist_axis)): + if not array_equivalent(np.array(append_axis), + np.array(exist_axis)): # ahah! 
-> reindex - if array_equivalent( - np.array(sorted(append_axis)), np.array(sorted(exist_axis)) - ): + if array_equivalent(np.array(sorted(append_axis)), + np.array(sorted(exist_axis))): append_axis = exist_axis # the non_index_axes info @@ -3917,15 +3924,15 @@ def _create_axes( transposed = new_index.axis == 1 # figure out data_columns and get out blocks - data_columns = self.validate_data_columns( - data_columns, min_itemsize, new_non_index_axes - ) + data_columns = self.validate_data_columns(data_columns, min_itemsize, + new_non_index_axes) frame = self.get_object(obj, transposed)._consolidate() - blocks, blk_items = self._get_blocks_and_items( - frame, table_exists, new_non_index_axes, self.values_axes, data_columns - ) + blocks, blk_items = self._get_blocks_and_items(frame, table_exists, + new_non_index_axes, + self.values_axes, + data_columns) # add my values vaxes = [] @@ -3936,12 +3943,14 @@ def _create_axes( name = None # we have a data_column - if data_columns and len(b_items) == 1 and b_items[0] in data_columns: + if data_columns and len( + b_items) == 1 and b_items[0] in data_columns: klass = DataIndexableCol name = b_items[0] if not (name is None or isinstance(name, str)): # TODO: should the message here be more specifically non-str? - raise ValueError("cannot have non-object label DataIndexableCol") + raise ValueError( + "cannot have non-object label DataIndexableCol") # make sure that we match up the existing columns # if we have an existing table @@ -3953,8 +3962,7 @@ def _create_axes( except (IndexError, KeyError) as err: raise ValueError( f"Incompatible appended table [{blocks}]" - f"with existing table [{self.values_axes}]" - ) from err + f"with existing table [{self.values_axes}]") from err else: existing_col = None @@ -3981,7 +3989,8 @@ def _create_axes( if is_categorical_dtype(data_converted.dtype): ordered = data_converted.ordered meta = "category" - metadata = np.array(data_converted.categories, copy=False).ravel() + metadata = np.array(data_converted.categories, + copy=False).ravel() data, dtype_name = _get_data_and_dtype_name(data_converted) @@ -4086,8 +4095,7 @@ def get_blk_items(mgr): jitems = ",".join([pprint_thing(item) for item in items]) raise ValueError( f"cannot match existing table structure for [{jitems}] " - "on appending data" - ) from err + "on appending data") from err blocks = new_blocks blk_items = new_blk_items @@ -4145,7 +4153,8 @@ def process_filter(field, filt): takers = op(values, filt) return obj.loc(axis=axis_number)[takers] - raise ValueError(f"cannot find the field [{field}] for filtering!") + raise ValueError( + f"cannot find the field [{field}] for filtering!") obj = process_filter(field, filt) @@ -4182,9 +4191,10 @@ def create_description( return d - def read_coordinates( - self, where=None, start: int | None = None, stop: int | None = None - ): + def read_coordinates(self, + where=None, + start: int | None = None, + stop: int | None = None): """ select coordinates (row numbers) from a table; return the coordinates object @@ -4201,10 +4211,11 @@ def read_coordinates( coords = selection.select_coords() if selection.filter is not None: for field, op, filt in selection.filter.format(): - data = self.read_column( - field, start=coords.min(), stop=coords.max() + 1 - ) - coords = coords[op(data.iloc[coords - coords.min()], filt).values] + data = self.read_column(field, + start=coords.min(), + stop=coords.max() + 1) + coords = coords[op(data.iloc[coords - coords.min()], + filt).values] return Index(coords) @@ -4227,7 +4238,8 @@ def 
read_column( return False if where is not None: - raise TypeError("read_column does not currently accept a where clause") + raise TypeError( + "read_column does not currently accept a where clause") # find the axes for a in self.axes: @@ -4235,8 +4247,7 @@ def read_column( if not a.is_data_indexable: raise ValueError( f"column [{column}] can not be extracted individually; " - "it is not data indexable" - ) + "it is not data indexable") # column must be an indexable or a data column c = getattr(self.table.cols, column) @@ -4381,10 +4392,12 @@ def write_data(self, chunksize: int | None, dropna: bool = False): # transpose the values so first dimension is last # reshape the values if needed values = [a.take_data() for a in self.values_axes] - values = [v.transpose(np.roll(np.arange(v.ndim), v.ndim - 1)) for v in values] + values = [ + v.transpose(np.roll(np.arange(v.ndim), v.ndim - 1)) for v in values + ] bvalues = [] for i, v in enumerate(values): - new_shape = (nrows,) + self.dtype[names[nindexes + i]].shape + new_shape = (nrows, ) + self.dtype[names[nindexes + i]].shape bvalues.append(values[i].reshape(new_shape)) # write the chunks @@ -4450,7 +4463,10 @@ def write_data_chunk( self.table.append(rows) self.table.flush() - def delete(self, where=None, start: int | None = None, stop: int | None = None): + def delete(self, + where=None, + start: int | None = None, + stop: int | None = None): # delete all rows (and return the nrows) if where is None or not len(where): @@ -4500,9 +4516,8 @@ def delete(self, where=None, start: int | None = None, stop: int | None = None): pg = groups.pop() for g in reversed(groups): rows = sorted_series.take(range(g, pg)) - table.remove_rows( - start=rows[rows.index[0]], stop=rows[rows.index[-1]] + 1 - ) + table.remove_rows(start=rows[rows.index[0]], + stop=rows[rows.index[-1]] + 1) pg = g self.table.flush() @@ -4547,13 +4562,12 @@ def read( result = self._read_axes(where=where, start=start, stop=stop) - info = ( - self.info.get(self.non_index_axes[0][0], {}) - if len(self.non_index_axes) - else {} - ) + info = (self.info.get(self.non_index_axes[0][0], {}) + if len(self.non_index_axes) else {}) - inds = [i for i, ax in enumerate(self.axes) if ax is self.index_axes[0]] + inds = [ + i for i, ax in enumerate(self.axes) if ax is self.index_axes[0] + ] assert len(inds) == 1 ind = inds[0] @@ -4595,7 +4609,9 @@ def read( df = DataFrame(values, columns=cols_, index=index_) else: # Categorical - df = DataFrame._from_arrays([values], columns=cols_, index=index_) + df = DataFrame._from_arrays([values], + columns=cols_, + index=index_) assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype) frames.append(df) @@ -4632,7 +4648,9 @@ def write(self, obj, data_columns=None, **kwargs): if not isinstance(obj, DataFrame): name = obj.name or "values" obj = obj.to_frame(name) - return super().write(obj=obj, data_columns=obj.columns.tolist(), **kwargs) + return super().write(obj=obj, + data_columns=obj.columns.tolist(), + **kwargs) def read( self, @@ -4701,7 +4719,9 @@ def get_attrs(self): self.levels = [] self.index_axes = [a for a in self.indexables if a.is_an_indexable] - self.values_axes = [a for a in self.indexables if not a.is_an_indexable] + self.values_axes = [ + a for a in self.indexables if not a.is_an_indexable + ] self.data_columns = [a.name for a in self.values_axes] @cache_readonly @@ -4714,11 +4734,14 @@ def indexables(self): # the index columns is just a simple index md = self.read_metadata("index") meta = "category" if md is not None else None - index_col = 
GenericIndexCol( - name="index", axis=0, table=self.table, meta=meta, metadata=md - ) + index_col = GenericIndexCol(name="index", + axis=0, + table=self.table, + meta=meta, + metadata=md) - _indexables: list[GenericIndexCol | GenericDataIndexableCol] = [index_col] + _indexables: list[GenericIndexCol + | GenericDataIndexableCol] = [index_col] for i, n in enumerate(d._v_names): assert isinstance(n, str) @@ -4779,14 +4802,18 @@ def read( df = df.set_index(self.levels) # remove names for 'level_%d' - df.index = df.index.set_names( - [None if self._re_levels.search(name) else name for name in df.index.names] - ) + df.index = df.index.set_names([ + None if self._re_levels.search(name) else name + for name in df.index.names + ]) return df -def _reindex_axis(obj: DataFrame, axis: int, labels: Index, other=None) -> DataFrame: +def _reindex_axis(obj: DataFrame, + axis: int, + labels: Index, + other=None) -> DataFrame: ax = obj._get_axis(axis) labels = ensure_index(labels) @@ -4856,13 +4883,15 @@ def _set_tz( return values # type: ignore[return-value] -def _convert_index(name: str, index: Index, encoding: str, errors: str) -> IndexCol: +def _convert_index(name: str, index: Index, encoding: str, + errors: str) -> IndexCol: assert isinstance(name, str) index_name = index.name # error: Argument 1 to "_get_data_and_dtype_name" has incompatible type "Index"; # expected "Union[ExtensionArray, ndarray]" - converted, dtype_name = _get_data_and_dtype_name(index) # type: ignore[arg-type] + converted, dtype_name = _get_data_and_dtype_name( + index) # type: ignore[arg-type] kind = _dtype_to_kind(dtype_name) atom = DataIndexableCol._get_atom(converted) @@ -4891,9 +4920,11 @@ def _convert_index(name: str, index: Index, encoding: str, errors: str) -> Index if inferred_type == "date": converted = np.asarray([v.toordinal() for v in values], dtype=np.int32) - return IndexCol( - name, converted, "date", _tables().Time32Col(), index_name=index_name - ) + return IndexCol(name, + converted, + "date", + _tables().Time32Col(), + index_name=index_name) elif inferred_type == "string": converted = _convert_string_array(values, encoding, errors) @@ -4907,9 +4938,11 @@ def _convert_index(name: str, index: Index, encoding: str, errors: str) -> Index ) elif inferred_type in ["integer", "floating"]: - return IndexCol( - name, values=converted, kind=kind, typ=atom, index_name=index_name - ) + return IndexCol(name, + values=converted, + kind=kind, + typ=atom, + index_name=index_name) else: assert isinstance(converted, np.ndarray) and converted.dtype == object assert kind == "object", kind @@ -4917,7 +4950,8 @@ def _convert_index(name: str, index: Index, encoding: str, errors: str) -> Index return IndexCol(name, converted, kind, atom, index_name=index_name) -def _unconvert_index(data, kind: str, encoding: str, errors: str) -> np.ndarray | Index: +def _unconvert_index(data, kind: str, encoding: str, + errors: str) -> np.ndarray | Index: index: Index | np.ndarray if kind == "datetime64": @@ -4926,15 +4960,18 @@ def _unconvert_index(data, kind: str, encoding: str, errors: str) -> np.ndarray index = TimedeltaIndex(data) elif kind == "date": try: - index = np.asarray([date.fromordinal(v) for v in data], dtype=object) + index = np.asarray([date.fromordinal(v) for v in data], + dtype=object) except (ValueError): - index = np.asarray([date.fromtimestamp(v) for v in data], dtype=object) + index = np.asarray([date.fromtimestamp(v) for v in data], + dtype=object) elif kind in ("integer", "float"): index = np.asarray(data) elif kind in 
("string"): - index = _unconvert_string_array( - data, nan_rep=None, encoding=encoding, errors=errors - ) + index = _unconvert_string_array(data, + nan_rep=None, + encoding=encoding, + errors=errors) elif kind == "object": index = np.asarray(data[0]) else: # pragma: no cover @@ -4967,8 +5004,7 @@ def _maybe_convert_for_string_atom( # after GH#8260 # this only would be hit for a multi-timezone dtype which is an error raise TypeError( - "too many timezones in this block, create separate data columns" - ) + "too many timezones in this block, create separate data columns") elif not (inferred_type == "string" or dtype_name == "object"): return bvalues @@ -4990,21 +5026,23 @@ def _maybe_convert_for_string_atom( col = data[i] inferred_type = lib.infer_dtype(col, skipna=False) if inferred_type != "string": - error_column_label = columns[i] if len(columns) > i else f"No.{i}" + error_column_label = columns[i] if len( + columns) > i else f"No.{i}" raise TypeError( f"Cannot serialize the column [{error_column_label}]\n" f"because its data contents are not [string] but " - f"[{inferred_type}] object dtype" - ) + f"[{inferred_type}] object dtype") # itemsize is the maximum length of a string (along any dimension) - data_converted = _convert_string_array(data, encoding, errors).reshape(data.shape) + data_converted = _convert_string_array(data, encoding, + errors).reshape(data.shape) itemsize = data_converted.itemsize # specified min_itemsize? if isinstance(min_itemsize, dict): - min_itemsize = int(min_itemsize.get(name) or min_itemsize.get("values") or 0) + min_itemsize = int( + min_itemsize.get(name) or min_itemsize.get("values") or 0) itemsize = max(min_itemsize or 0, itemsize) # check for column in the values conflicts @@ -5017,7 +5055,8 @@ def _maybe_convert_for_string_atom( return data_converted -def _convert_string_array(data: np.ndarray, encoding: str, errors: str) -> np.ndarray: +def _convert_string_array(data: np.ndarray, encoding: str, + errors: str) -> np.ndarray: """ Take a string-like that is object dtype and coerce to a fixed size string type. @@ -5034,11 +5073,8 @@ def _convert_string_array(data: np.ndarray, encoding: str, errors: str) -> np.nd """ # encode if needed if len(data): - data = ( - Series(data.ravel()) - .str.encode(encoding, errors) - ._values.reshape(data.shape) - ) + data = (Series(data.ravel()).str.encode( + encoding, errors)._values.reshape(data.shape)) # create the sized dtype ensured = ensure_object(data.ravel()) @@ -5048,9 +5084,8 @@ def _convert_string_array(data: np.ndarray, encoding: str, errors: str) -> np.nd return data -def _unconvert_string_array( - data: np.ndarray, nan_rep, encoding: str, errors: str -) -> np.ndarray: +def _unconvert_string_array(data: np.ndarray, nan_rep, encoding: str, + errors: str) -> np.ndarray: """ Inverse of _convert_string_array. 
@@ -5087,7 +5122,8 @@ def _unconvert_string_array( return data.reshape(shape) -def _maybe_convert(values: np.ndarray, val_kind: str, encoding: str, errors: str): +def _maybe_convert(values: np.ndarray, val_kind: str, encoding: str, + errors: str): assert isinstance(val_kind, str), type(val_kind) if _need_convert(val_kind): conv = _get_converter(val_kind, encoding, errors) @@ -5100,8 +5136,7 @@ def _get_converter(kind: str, encoding: str, errors: str): return lambda x: np.asarray(x, dtype="M8[ns]") elif kind == "string": return lambda x: _unconvert_string_array( - x, nan_rep=None, encoding=encoding, errors=errors - ) + x, nan_rep=None, encoding=encoding, errors=errors) else: # pragma: no cover raise ValueError(f"invalid kind {kind}") @@ -5127,7 +5162,8 @@ def _maybe_adjust_name(name: str, version: Sequence[int]) -> str: str """ if isinstance(version, str) or len(version) < 3: - raise ValueError("Version is incorrect, expected sequence of 3 integers.") + raise ValueError( + "Version is incorrect, expected sequence of 3 integers.") if version[0] == 0 and version[1] <= 10 and version[2] == 0: m = re.search(r"values_block_(\d+)", name) @@ -5235,9 +5271,10 @@ def __init__( stop = self.table.nrows self.coordinates = np.arange(start, stop)[where] elif issubclass(where.dtype.type, np.integer): - if (self.start is not None and (where < self.start).any()) or ( - self.stop is not None and (where >= self.stop).any() - ): + if (self.start is not None and + (where < self.start).any()) or ( + self.stop is not None and + (where >= self.stop).any()): raise ValueError( "where must have index locations >= start and < stop" ) @@ -5258,20 +5295,20 @@ def generate(self, where): q = self.table.queryables() try: - return PyTablesExpr(where, queryables=q, encoding=self.table.encoding) + return PyTablesExpr(where, + queryables=q, + encoding=self.table.encoding) except NameError as err: # raise a nice message, suggesting that the user should use # data_columns qkeys = ",".join(q.keys()) - msg = dedent( - f"""\ + msg = dedent(f"""\ The passed where expression: {where} contains an invalid variable reference all of the variable references must be a reference to an axis (e.g. 
'index' or 'columns'), or a data_column The currently defined references are: {qkeys} - """ - ) + """) raise ValueError(msg) from err def select(self): @@ -5279,9 +5316,9 @@ def select(self): generate the selection """ if self.condition is not None: - return self.table.table.read_where( - self.condition.format(), start=self.start, stop=self.stop - ) + return self.table.table.read_where(self.condition.format(), + start=self.start, + stop=self.stop) elif self.coordinates is not None: return self.table.table.read_coordinates(self.coordinates) return self.table.table.read(start=self.start, stop=self.stop) @@ -5302,9 +5339,10 @@ def select_coords(self): stop += nrows if self.condition is not None: - return self.table.table.get_where_list( - self.condition.format(), start=start, stop=stop, sort=True - ) + return self.table.table.get_where_list(self.condition.format(), + start=start, + stop=stop, + sort=True) elif self.coordinates is not None: return self.coordinates diff --git a/pandas/io/sas/sas7bdat.py b/pandas/io/sas/sas7bdat.py index ddde2c9f8644c..95a5732e8d7fc 100644 --- a/pandas/io/sas/sas7bdat.py +++ b/pandas/io/sas/sas7bdat.py @@ -185,7 +185,8 @@ def __init__( self.column_formats: list[str] = [] self.columns: list[_Column] = [] - self._current_page_data_subheader_pointers: list[_SubheaderPointer] = [] + self._current_page_data_subheader_pointers: list[ + _SubheaderPointer] = [] self._cached_page = None self._column_data_lengths: list[int] = [] self._column_data_offsets: list[int] = [] @@ -229,7 +230,7 @@ def _get_properties(self) -> None: # Check magic number self._path_or_buf.seek(0) self._cached_page = self._path_or_buf.read(288) - if self._cached_page[0 : len(const.magic)] != const.magic: + if self._cached_page[0:len(const.magic)] != const.magic: raise ValueError("magic number mismatch (not a SAS file?)") # Get alignment information @@ -252,7 +253,8 @@ def _get_properties(self) -> None: total_align = align1 + align2 # Get endianness information - buf = self._read_bytes(const.endianness_offset, const.endianness_length) + buf = self._read_bytes(const.endianness_offset, + const.endianness_length) if buf == b"\x01": self.byte_order = "<" else: @@ -277,85 +279,74 @@ def _get_properties(self) -> None: buf = self._read_bytes(const.dataset_offset, const.dataset_length) self.name = buf.rstrip(b"\x00 ") if self.convert_header_text: - self.name = self.name.decode(self.encoding or self.default_encoding) + self.name = self.name.decode(self.encoding + or self.default_encoding) buf = self._read_bytes(const.file_type_offset, const.file_type_length) self.file_type = buf.rstrip(b"\x00 ") if self.convert_header_text: - self.file_type = self.file_type.decode( - self.encoding or self.default_encoding - ) + self.file_type = self.file_type.decode(self.encoding + or self.default_encoding) # Timestamp is epoch 01/01/1960 epoch = datetime(1960, 1, 1) - x = self._read_float( - const.date_created_offset + align1, const.date_created_length - ) + x = self._read_float(const.date_created_offset + align1, + const.date_created_length) self.date_created = epoch + pd.to_timedelta(x, unit="s") - x = self._read_float( - const.date_modified_offset + align1, const.date_modified_length - ) + x = self._read_float(const.date_modified_offset + align1, + const.date_modified_length) self.date_modified = epoch + pd.to_timedelta(x, unit="s") - self.header_length = self._read_int( - const.header_size_offset + align1, const.header_size_length - ) + self.header_length = self._read_int(const.header_size_offset + align1, + 
const.header_size_length) # Read the rest of the header into cached_page. buf = self._path_or_buf.read(self.header_length - 288) self._cached_page += buf # error: Argument 1 to "len" has incompatible type "Optional[bytes]"; # expected "Sized" - if len(self._cached_page) != self.header_length: # type: ignore[arg-type] + if len(self._cached_page + ) != self.header_length: # type: ignore[arg-type] raise ValueError("The SAS7BDAT file appears to be truncated.") - self._page_length = self._read_int( - const.page_size_offset + align1, const.page_size_length - ) - self._page_count = self._read_int( - const.page_count_offset + align1, const.page_count_length - ) + self._page_length = self._read_int(const.page_size_offset + align1, + const.page_size_length) + self._page_count = self._read_int(const.page_count_offset + align1, + const.page_count_length) - buf = self._read_bytes( - const.sas_release_offset + total_align, const.sas_release_length - ) + buf = self._read_bytes(const.sas_release_offset + total_align, + const.sas_release_length) self.sas_release = buf.rstrip(b"\x00 ") if self.convert_header_text: self.sas_release = self.sas_release.decode( - self.encoding or self.default_encoding - ) + self.encoding or self.default_encoding) - buf = self._read_bytes( - const.sas_server_type_offset + total_align, const.sas_server_type_length - ) + buf = self._read_bytes(const.sas_server_type_offset + total_align, + const.sas_server_type_length) self.server_type = buf.rstrip(b"\x00 ") if self.convert_header_text: self.server_type = self.server_type.decode( - self.encoding or self.default_encoding - ) + self.encoding or self.default_encoding) - buf = self._read_bytes( - const.os_version_number_offset + total_align, const.os_version_number_length - ) + buf = self._read_bytes(const.os_version_number_offset + total_align, + const.os_version_number_length) self.os_version = buf.rstrip(b"\x00 ") if self.convert_header_text: - self.os_version = self.os_version.decode( - self.encoding or self.default_encoding - ) + self.os_version = self.os_version.decode(self.encoding + or self.default_encoding) - buf = self._read_bytes(const.os_name_offset + total_align, const.os_name_length) + buf = self._read_bytes(const.os_name_offset + total_align, + const.os_name_length) buf = buf.rstrip(b"\x00 ") if len(buf) > 0: self.os_name = buf.decode(self.encoding or self.default_encoding) else: - buf = self._read_bytes( - const.os_maker_offset + total_align, const.os_maker_length - ) + buf = self._read_bytes(const.os_maker_offset + total_align, + const.os_maker_length) self.os_name = buf.rstrip(b"\x00 ") if self.convert_header_text: - self.os_name = self.os_name.decode( - self.encoding or self.default_encoding - ) + self.os_name = self.os_name.decode(self.encoding + or self.default_encoding) def __next__(self): da = self.read(nrows=self.chunksize or 1) @@ -396,7 +387,7 @@ def _read_bytes(self, offset: int, length: int): if offset + length > len(self._cached_page): self.close() raise ValueError("The cached page is too small.") - return self._cached_page[offset : offset + length] + return self._cached_page[offset:offset + length] def _parse_metadata(self) -> None: done = False @@ -405,7 +396,8 @@ def _parse_metadata(self) -> None: if len(self._cached_page) <= 0: break if len(self._cached_page) != self._page_length: - raise ValueError("Failed to read a meta data page from the SAS file.") + raise ValueError( + "Failed to read a meta data page from the SAS file.") done = self._process_page_meta() def _process_page_meta(self) -> bool: @@ 
-415,41 +407,38 @@ def _process_page_meta(self) -> bool: self._process_page_metadata() is_data_page = self._current_page_type & const.page_data_type is_mix_page = self._current_page_type in const.page_mix_types - return bool( - is_data_page - or is_mix_page - or self._current_page_data_subheader_pointers != [] - ) + return bool(is_data_page or is_mix_page + or self._current_page_data_subheader_pointers != []) def _read_page_header(self): bit_offset = self._page_bit_offset tx = const.page_type_offset + bit_offset self._current_page_type = self._read_int(tx, const.page_type_length) tx = const.block_count_offset + bit_offset - self._current_page_block_count = self._read_int(tx, const.block_count_length) + self._current_page_block_count = self._read_int( + tx, const.block_count_length) tx = const.subheader_count_offset + bit_offset self._current_page_subheaders_count = self._read_int( - tx, const.subheader_count_length - ) + tx, const.subheader_count_length) def _process_page_metadata(self) -> None: bit_offset = self._page_bit_offset for i in range(self._current_page_subheaders_count): pointer = self._process_subheader_pointers( - const.subheader_pointers_offset + bit_offset, i - ) + const.subheader_pointers_offset + bit_offset, i) if pointer.length == 0: continue if pointer.compression == const.truncated_subheader_id: continue - subheader_signature = self._read_subheader_signature(pointer.offset) + subheader_signature = self._read_subheader_signature( + pointer.offset) subheader_index = self._get_subheader_index( - subheader_signature, pointer.compression, pointer.ptype - ) + subheader_signature, pointer.compression, pointer.ptype) self._process_subheader(subheader_index, pointer) - def _get_subheader_index(self, signature: bytes, compression, ptype) -> int: + def _get_subheader_index(self, signature: bytes, compression, + ptype) -> int: # TODO: return here could be made an enum index = const.subheader_signature_to_index.get(signature) if index is None: @@ -463,8 +452,8 @@ def _get_subheader_index(self, signature: bytes, compression, ptype) -> int: return index def _process_subheader_pointers( - self, offset: int, subheader_pointer_index: int - ) -> _SubheaderPointer: + self, offset: int, + subheader_pointer_index: int) -> _SubheaderPointer: subheader_pointer_length = self._subheader_pointer_length total_offset = offset + subheader_pointer_length * subheader_pointer_index @@ -480,9 +469,8 @@ def _process_subheader_pointers( subheader_type = self._read_int(total_offset, 1) - x = _SubheaderPointer( - subheader_offset, subheader_length, subheader_compression, subheader_type - ) + x = _SubheaderPointer(subheader_offset, subheader_length, + subheader_compression, subheader_type) return x @@ -490,9 +478,8 @@ def _read_subheader_signature(self, offset: int) -> bytes: subheader_signature = self._read_bytes(offset, self._int_length) return subheader_signature - def _process_subheader( - self, subheader_index: int, pointer: _SubheaderPointer - ) -> None: + def _process_subheader(self, subheader_index: int, + pointer: _SubheaderPointer) -> None: offset = pointer.offset length = pointer.length @@ -533,17 +520,13 @@ def _process_rowsize_subheader(self, offset: int, length: int) -> None: lcp_offset += 378 self.row_length = self._read_int( - offset + const.row_length_offset_multiplier * int_len, int_len - ) + offset + const.row_length_offset_multiplier * int_len, int_len) self.row_count = self._read_int( - offset + const.row_count_offset_multiplier * int_len, int_len - ) + offset + 
const.row_count_offset_multiplier * int_len, int_len) self.col_count_p1 = self._read_int( - offset + const.col_count_p1_multiplier * int_len, int_len - ) + offset + const.col_count_p1_multiplier * int_len, int_len) self.col_count_p2 = self._read_int( - offset + const.col_count_p2_multiplier * int_len, int_len - ) + offset + const.col_count_p2_multiplier * int_len, int_len) mx = const.row_count_on_mix_page_offset_multiplier * int_len self._mix_page_row_count = self._read_int(offset + mx, int_len) self._lcs = self._read_int(lcs_offset, 2) @@ -554,10 +537,8 @@ def _process_columnsize_subheader(self, offset: int, length: int) -> None: offset += int_len self.column_count = self._read_int(offset, int_len) if self.col_count_p1 + self.col_count_p2 != self.column_count: - print( - f"Warning: column count mismatch ({self.col_count_p1} + " - f"{self.col_count_p2} != {self.column_count})\n" - ) + print(f"Warning: column count mismatch ({self.col_count_p1} + " + f"{self.col_count_p2} != {self.column_count})\n") # Unknown purpose def _process_subheader_counts(self, offset: int, length: int) -> None: @@ -595,74 +576,62 @@ def _process_columntext_subheader(self, offset: int, length: int) -> None: if self.U64: offset1 += 4 buf = self._read_bytes(offset1, self._lcp) - self.creator_proc = buf[0 : self._lcp] + self.creator_proc = buf[0:self._lcp] elif compression_literal == const.rle_compression: offset1 = offset + 40 if self.U64: offset1 += 4 buf = self._read_bytes(offset1, self._lcp) - self.creator_proc = buf[0 : self._lcp] + self.creator_proc = buf[0:self._lcp] elif self._lcs > 0: self._lcp = 0 offset1 = offset + 16 if self.U64: offset1 += 4 buf = self._read_bytes(offset1, self._lcs) - self.creator_proc = buf[0 : self._lcp] + self.creator_proc = buf[0:self._lcp] if self.convert_header_text: if hasattr(self, "creator_proc"): self.creator_proc = self.creator_proc.decode( - self.encoding or self.default_encoding - ) + self.encoding or self.default_encoding) def _process_columnname_subheader(self, offset: int, length: int) -> None: int_len = self._int_length offset += int_len column_name_pointers_count = (length - 2 * int_len - 12) // 8 for i in range(column_name_pointers_count): - text_subheader = ( - offset - + const.column_name_pointer_length * (i + 1) - + const.column_name_text_subheader_offset - ) - col_name_offset = ( - offset - + const.column_name_pointer_length * (i + 1) - + const.column_name_offset_offset - ) - col_name_length = ( - offset - + const.column_name_pointer_length * (i + 1) - + const.column_name_length_offset - ) - - idx = self._read_int( - text_subheader, const.column_name_text_subheader_length - ) - col_offset = self._read_int( - col_name_offset, const.column_name_offset_length - ) - col_len = self._read_int(col_name_length, const.column_name_length_length) + text_subheader = (offset + const.column_name_pointer_length * + (i + 1) + + const.column_name_text_subheader_offset) + col_name_offset = (offset + const.column_name_pointer_length * + (i + 1) + const.column_name_offset_offset) + col_name_length = (offset + const.column_name_pointer_length * + (i + 1) + const.column_name_length_offset) + + idx = self._read_int(text_subheader, + const.column_name_text_subheader_length) + col_offset = self._read_int(col_name_offset, + const.column_name_offset_length) + col_len = self._read_int(col_name_length, + const.column_name_length_length) name_str = self.column_names_strings[idx] - self.column_names.append(name_str[col_offset : col_offset + col_len]) + 
self.column_names.append(name_str[col_offset:col_offset + col_len]) - def _process_columnattributes_subheader(self, offset: int, length: int) -> None: + def _process_columnattributes_subheader(self, offset: int, + length: int) -> None: int_len = self._int_length - column_attributes_vectors_count = (length - 2 * int_len - 12) // (int_len + 8) + column_attributes_vectors_count = (length - 2 * int_len - + 12) // (int_len + 8) for i in range(column_attributes_vectors_count): - col_data_offset = ( - offset + int_len + const.column_data_offset_offset + i * (int_len + 8) - ) - col_data_len = ( - offset - + 2 * int_len - + const.column_data_length_offset - + i * (int_len + 8) - ) - col_types = ( - offset + 2 * int_len + const.column_type_offset + i * (int_len + 8) - ) + col_data_offset = (offset + int_len + + const.column_data_offset_offset + i * + (int_len + 8)) + col_data_len = (offset + 2 * int_len + + const.column_data_length_offset + i * + (int_len + 8)) + col_types = (offset + 2 * int_len + const.column_type_offset + i * + (int_len + 8)) x = self._read_int(col_data_offset, int_len) self._column_data_offsets.append(x) @@ -680,38 +649,39 @@ def _process_columnlist_subheader(self, offset: int, length: int) -> None: def _process_format_subheader(self, offset: int, length: int) -> None: int_len = self._int_length text_subheader_format = ( - offset + const.column_format_text_subheader_index_offset + 3 * int_len - ) + offset + const.column_format_text_subheader_index_offset + + 3 * int_len) col_format_offset = offset + const.column_format_offset_offset + 3 * int_len col_format_len = offset + const.column_format_length_offset + 3 * int_len text_subheader_label = ( - offset + const.column_label_text_subheader_index_offset + 3 * int_len - ) + offset + const.column_label_text_subheader_index_offset + + 3 * int_len) col_label_offset = offset + const.column_label_offset_offset + 3 * int_len col_label_len = offset + const.column_label_length_offset + 3 * int_len - x = self._read_int( - text_subheader_format, const.column_format_text_subheader_index_length - ) + x = self._read_int(text_subheader_format, + const.column_format_text_subheader_index_length) format_idx = min(x, len(self.column_names_strings) - 1) - format_start = self._read_int( - col_format_offset, const.column_format_offset_length - ) - format_len = self._read_int(col_format_len, const.column_format_length_length) + format_start = self._read_int(col_format_offset, + const.column_format_offset_length) + format_len = self._read_int(col_format_len, + const.column_format_length_length) label_idx = self._read_int( - text_subheader_label, const.column_label_text_subheader_index_length - ) + text_subheader_label, + const.column_label_text_subheader_index_length) label_idx = min(label_idx, len(self.column_names_strings) - 1) - label_start = self._read_int(col_label_offset, const.column_label_offset_length) - label_len = self._read_int(col_label_len, const.column_label_length_length) + label_start = self._read_int(col_label_offset, + const.column_label_offset_length) + label_len = self._read_int(col_label_len, + const.column_label_length_length) label_names = self.column_names_strings[label_idx] - column_label = label_names[label_start : label_start + label_len] + column_label = label_names[label_start:label_start + label_len] format_names = self.column_names_strings[format_idx] - column_format = format_names[format_start : format_start + format_len] + column_format = format_names[format_start:format_start + format_len] current_column_number = 
len(self.columns) col = _Column( @@ -769,8 +739,7 @@ def _read_next_page(self): self.close() msg = ( "failed to read complete page from file (read " - f"{len(self._cached_page):d} of {self._page_length:d} bytes)" - ) + f"{len(self._cached_page):d} of {self._page_length:d} bytes)") raise ValueError(msg) self._read_page_header() @@ -798,7 +767,8 @@ def _chunk_to_dataframe(self) -> DataFrame: name = self.column_names[j] if self._column_types[j] == b"d": - col_arr = self._byte_chunk[jb, :].view(dtype=self.byte_order + "d") + col_arr = self._byte_chunk[jb, :].view(dtype=self.byte_order + + "d") rslt[name] = pd.Series(col_arr, dtype=np.float64, index=ix) if self.convert_dates: if self.column_formats[j] in const.sas_date_formats: @@ -810,15 +780,15 @@ def _chunk_to_dataframe(self) -> DataFrame: rslt[name] = pd.Series(self._string_chunk[js, :], index=ix) if self.convert_text and (self.encoding is not None): rslt[name] = rslt[name].str.decode( - self.encoding or self.default_encoding - ) + self.encoding or self.default_encoding) if self.blank_missing: ii = rslt[name].str.len() == 0 rslt[name][ii] = np.nan js += 1 else: self.close() - raise ValueError(f"unknown column type {repr(self._column_types[j])}") + raise ValueError( + f"unknown column type {repr(self._column_types[j])}") df = DataFrame(rslt, columns=self.column_names, index=ix, copy=False) return df diff --git a/pandas/io/sql.py b/pandas/io/sql.py index e1af6c7288992..6bef529126172 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -78,9 +78,9 @@ def _process_parse_dates_argument(parse_dates): return parse_dates -def _handle_date_column( - col, utc: bool | None = None, format: str | dict[str, Any] | None = None -): +def _handle_date_column(col, + utc: bool | None = None, + format: str | dict[str, Any] | None = None): if isinstance(format, dict): # GH35185 Allow custom error values in parse_dates argument of # read_sql like functions. 
@@ -91,10 +91,8 @@ def _handle_date_column( else: # Allow passing of formatting string for integers # GH17855 - if format is None and ( - issubclass(col.dtype.type, np.floating) - or issubclass(col.dtype.type, np.integer) - ): + if format is None and (issubclass(col.dtype.type, np.floating) + or issubclass(col.dtype.type, np.integer)): format = "s" if format in ["D", "d", "h", "m", "s", "ms", "us", "ns"]: return to_datetime(col, errors="coerce", unit=format, utc=utc) @@ -136,7 +134,9 @@ def _wrap_result( dtype: DtypeArg | None = None, ): """Wrap result set of query in a DataFrame.""" - frame = DataFrame.from_records(data, columns=columns, coerce_float=coerce_float) + frame = DataFrame.from_records(data, + columns=columns, + coerce_float=coerce_float) if dtype: frame = frame.astype(dtype) @@ -680,8 +680,7 @@ def to_sql( frame = frame.to_frame() elif not isinstance(frame, DataFrame): raise NotImplementedError( - "'frame' argument should be either a Series or a DataFrame" - ) + "'frame' argument should be either a Series or a DataFrame") return pandas_sql.to_sql( frame, @@ -744,7 +743,8 @@ def pandasSQL_builder(con, schema: str | None = None): else: con = sqlalchemy.create_engine(con) - if sqlalchemy is not None and isinstance(con, sqlalchemy.engine.Connectable): + if sqlalchemy is not None and isinstance(con, + sqlalchemy.engine.Connectable): return SQLDatabase(con, schema=schema) warnings.warn( @@ -823,7 +823,8 @@ def create(self): elif self.if_exists == "append": pass else: - raise ValueError(f"'{self.if_exists}' is not valid for if_exists") + raise ValueError( + f"'{self.if_exists}' is not valid for if_exists") else: self._execute_create() @@ -866,7 +867,8 @@ def insert_data(self): try: temp.reset_index(inplace=True) except ValueError as err: - raise ValueError(f"duplicate name in index/columns: {err}") from err + raise ValueError( + f"duplicate name in index/columns: {err}") from err else: temp = self.frame @@ -897,9 +899,9 @@ def insert_data(self): return column_names, data_list - def insert( - self, chunksize: int | None = None, method: str | None = None - ) -> int | None: + def insert(self, + chunksize: int | None = None, + method: str | None = None) -> int | None: # set insert method if method is None: @@ -954,15 +956,15 @@ def _query_iterator( data = result.fetchmany(chunksize) if not data: if not has_read_data: - yield DataFrame.from_records( - [], columns=columns, coerce_float=coerce_float - ) + yield DataFrame.from_records([], + columns=columns, + coerce_float=coerce_float) break else: has_read_data = True - self.frame = DataFrame.from_records( - data, columns=columns, coerce_float=coerce_float - ) + self.frame = DataFrame.from_records(data, + columns=columns, + coerce_float=coerce_float) self._harmonize_columns(parse_dates=parse_dates) @@ -971,7 +973,11 @@ def _query_iterator( yield self.frame - def read(self, coerce_float=True, parse_dates=None, columns=None, chunksize=None): + def read(self, + coerce_float=True, + parse_dates=None, + columns=None, + chunksize=None): from sqlalchemy import select if columns is not None and len(columns) > 0: @@ -995,9 +1001,9 @@ def read(self, coerce_float=True, parse_dates=None, columns=None, chunksize=None ) else: data = result.fetchall() - self.frame = DataFrame.from_records( - data, columns=column_names, coerce_float=coerce_float - ) + self.frame = DataFrame.from_records(data, + columns=column_names, + coerce_float=coerce_float) self._harmonize_columns(parse_dates=parse_dates) @@ -1017,16 +1023,12 @@ def _index_name(self, index, 
index_label): if len(index_label) != nlevels: raise ValueError( "Length of 'index_label' should match number of " - f"levels, which is {nlevels}" - ) + f"levels, which is {nlevels}") else: return index_label # return the used column labels for the index columns - if ( - nlevels == 1 - and "index" not in self.frame.columns - and self.frame.index.name is None - ): + if (nlevels == 1 and "index" not in self.frame.columns + and self.frame.index.name is None): return ["index"] else: return com.fill_missing_names(self.frame.index.names) @@ -1046,10 +1048,9 @@ def _get_column_names_and_types(self, dtype_mapper): idx_type = dtype_mapper(self.frame.index._get_level_values(i)) column_names_and_types.append((str(idx_label), idx_type, True)) - column_names_and_types += [ - (str(self.frame.columns[i]), dtype_mapper(self.frame.iloc[:, i]), False) - for i in range(len(self.frame.columns)) - ] + column_names_and_types += [(str(self.frame.columns[i]), + dtype_mapper(self.frame.iloc[:, i]), False) + for i in range(len(self.frame.columns))] return column_names_and_types @@ -1061,7 +1062,8 @@ def _create_table_setup(self): ) from sqlalchemy.schema import MetaData - column_names_and_types = self._get_column_names_and_types(self._sqlalchemy_type) + column_names_and_types = self._get_column_names_and_types( + self._sqlalchemy_type) columns = [ Column(name, typ, index=is_index) @@ -1109,17 +1111,15 @@ def _harmonize_columns(self, parse_dates=None): fmt = parse_dates[col_name] except TypeError: fmt = None - self.frame[col_name] = _handle_date_column(df_col, format=fmt) + self.frame[col_name] = _handle_date_column(df_col, + format=fmt) continue # the type the dataframe column should have col_type = self._get_dtype(sql_col.type) - if ( - col_type is datetime - or col_type is date - or col_type is DatetimeTZDtype - ): + if (col_type is datetime or col_type is date + or col_type is DatetimeTZDtype): # Convert tz-aware Datetime SQL columns to UTC utc = col_type is DatetimeTZDtype self.frame[col_name] = _handle_date_column(df_col, utc=utc) @@ -1130,7 +1130,8 @@ def _harmonize_columns(self, parse_dates=None): elif len(df_col) == df_col.count(): # No NA values, can convert ints and bools if col_type is np.dtype("int64") or col_type is bool: - self.frame[col_name] = df_col.astype(col_type, copy=False) + self.frame[col_name] = df_col.astype(col_type, + copy=False) except KeyError: pass # this column not in results @@ -1191,7 +1192,8 @@ def _sqlalchemy_type(self, col): elif col.dtype.name.lower() in ("uint16", "int32"): return Integer elif col.dtype.name.lower() == "uint64": - raise ValueError("Unsigned 64 bit integer datatype is not supported") + raise ValueError( + "Unsigned 64 bit integer datatype is not supported") else: return BigInteger elif col_type == "boolean": @@ -1241,10 +1243,8 @@ class PandasSQL(PandasObject): """ def read_sql(self, *args, **kwargs): - raise ValueError( - "PandasSQL must be created with an SQLAlchemy " - "connectable or sqlite connection" - ) + raise ValueError("PandasSQL must be created with an SQLAlchemy " + "connectable or sqlite connection") def to_sql( self, @@ -1258,13 +1258,12 @@ def to_sql( dtype: DtypeArg | None = None, method=None, ) -> int | None: - raise ValueError( - "PandasSQL must be created with an SQLAlchemy " - "connectable or sqlite connection" - ) + raise ValueError("PandasSQL must be created with an SQLAlchemy " + "connectable or sqlite connection") class BaseEngine: + def insert_records( self, table: SQLTable, @@ -1284,10 +1283,10 @@ def insert_records( class 
SQLAlchemyEngine(BaseEngine): + def __init__(self): import_optional_dependency( - "sqlalchemy", extra="sqlalchemy is required for SQL support." - ) + "sqlalchemy", extra="sqlalchemy is required for SQL support.") def insert_records( self, @@ -1340,8 +1339,7 @@ def get_engine(engine: str) -> BaseEngine: "sqlalchemy is required for sql I/O " "support.\n" "Trying to import the above resulted in these errors:" - f"{error_msgs}" - ) + f"{error_msgs}") elif engine == "sqlalchemy": return SQLAlchemyEngine() @@ -1587,7 +1585,8 @@ def prep_table( # Type[str], Type[float], Type[int], Type[complex], Type[bool], # Type[object]]]]"; expected type "Union[ExtensionDtype, str, # dtype[Any], Type[object]]" - dtype = {col_name: dtype for col_name in frame} # type: ignore[misc] + dtype = {col_name: dtype + for col_name in frame} # type: ignore[misc] else: dtype = cast(dict, dtype) @@ -1598,7 +1597,8 @@ def prep_table( for col, my_type in dtype.items(): if not isinstance(to_instance(my_type), TypeEngine): - raise ValueError(f"The type of {col} is not a SQLAlchemy type") + raise ValueError( + f"The type of {col} is not a SQLAlchemy type") table = SQLTable( name, @@ -1629,14 +1629,14 @@ def check_case_sensitive( with self.connectable.connect() as conn: insp = inspect(conn) - table_names = insp.get_table_names(schema=schema or self.meta.schema) + table_names = insp.get_table_names( + schema=schema or self.meta.schema) if name not in table_names: msg = ( f"The provided table name '{name}' is not found exactly as " "such in the database after writing the table, possibly " "due to case sensitivity issues. Consider using lower " - "case table names." - ) + "case table names.") warnings.warn(msg, UserWarning) def to_sql( @@ -1745,9 +1745,10 @@ def get_table(self, table_name: str, schema: str | None = None): ) schema = schema or self.meta.schema - tbl = Table( - table_name, self.meta, autoload_with=self.connectable, schema=schema - ) + tbl = Table(table_name, + self.meta, + autoload_with=self.connectable, + schema=schema) for column in tbl.columns: if isinstance(column.type, Numeric): column.type.asdecimal = False @@ -1756,7 +1757,9 @@ def get_table(self, table_name: str, schema: str | None = None): def drop_table(self, table_name: str, schema: str | None = None): schema = schema or self.meta.schema if self.has_table(table_name, schema): - self.meta.reflect(bind=self.connectable, only=[table_name], schema=schema) + self.meta.reflect(bind=self.connectable, + only=[table_name], + schema=schema) self.get_table(table_name, schema).drop(bind=self.connectable) self.meta.clear() @@ -1798,7 +1801,8 @@ def _get_unicode_name(name): try: uname = str(name).encode("utf-8", "strict").decode("utf-8") except UnicodeError as err: - raise ValueError(f"Cannot convert identifier to UTF-8: '{name}'") from err + raise ValueError( + f"Cannot convert identifier to UTF-8: '{name}'") from err return uname @@ -1871,7 +1875,8 @@ def _execute_insert(self, conn, keys, data_iter) -> int: def _execute_insert_multi(self, conn, keys, data_iter) -> int: data_list = list(data_iter) flattened_data = [x for row in data_list for x in row] - conn.execute(self.insert_statement(num_rows=len(data_list)), flattened_data) + conn.execute(self.insert_statement(num_rows=len(data_list)), + flattened_data) return conn.rowcount def _create_table_setup(self): @@ -1880,11 +1885,13 @@ def _create_table_setup(self): structure of a DataFrame. The first entry will be a CREATE TABLE statement while the rest will be CREATE INDEX statements. 
""" - column_names_and_types = self._get_column_names_and_types(self._sql_type_name) + column_names_and_types = self._get_column_names_and_types( + self._sql_type_name) escape = _get_valid_sqlite_name create_tbl_stmts = [ - escape(cname) + " " + ctype for cname, ctype, _ in column_names_and_types + escape(cname) + " " + ctype + for cname, ctype, _ in column_names_and_types ] if self.keys is not None and len(self.keys): @@ -1894,34 +1901,26 @@ def _create_table_setup(self): keys = self.keys cnames_br = ", ".join([escape(c) for c in keys]) create_tbl_stmts.append( - f"CONSTRAINT {self.name}_pk PRIMARY KEY ({cnames_br})" - ) + f"CONSTRAINT {self.name}_pk PRIMARY KEY ({cnames_br})") if self.schema: schema_name = self.schema + "." else: schema_name = "" create_stmts = [ - "CREATE TABLE " - + schema_name - + escape(self.name) - + " (\n" - + ",\n ".join(create_tbl_stmts) - + "\n)" + "CREATE TABLE " + schema_name + escape(self.name) + " (\n" + + ",\n ".join(create_tbl_stmts) + "\n)" ] - ix_cols = [cname for cname, _, is_index in column_names_and_types if is_index] + ix_cols = [ + cname for cname, _, is_index in column_names_and_types if is_index + ] if len(ix_cols): cnames = "_".join(ix_cols) cnames_br = ",".join([escape(c) for c in ix_cols]) - create_stmts.append( - "CREATE INDEX " - + escape("ix_" + self.name + "_" + cnames) - + "ON " - + escape(self.name) - + " (" - + cnames_br - + ")" - ) + create_stmts.append("CREATE INDEX " + + escape("ix_" + self.name + "_" + cnames) + + "ON " + escape(self.name) + " (" + cnames_br + + ")") return create_stmts @@ -2022,9 +2021,9 @@ def _query_iterator( if not data: cursor.close() if not has_read_data: - yield DataFrame.from_records( - [], columns=columns, coerce_float=coerce_float - ) + yield DataFrame.from_records([], + columns=columns, + coerce_float=coerce_float) break else: has_read_data = True @@ -2141,7 +2140,8 @@ def to_sql( # Type[str], Type[float], Type[int], Type[complex], Type[bool], # Type[object]]]]"; expected type "Union[ExtensionDtype, str, # dtype[Any], Type[object]]" - dtype = {col_name: dtype for col_name in frame} # type: ignore[misc] + dtype = {col_name: dtype + for col_name in frame} # type: ignore[misc] else: dtype = cast(dict, dtype) @@ -2226,6 +2226,8 @@ def get_schema( .. versionadded:: 1.2.0 """ pandas_sql = pandasSQL_builder(con=con) - return pandas_sql._create_sql_schema( - frame, name, keys=keys, dtype=dtype, schema=schema - ) + return pandas_sql._create_sql_schema(frame, + name, + keys=keys, + dtype=dtype, + schema=schema) diff --git a/pandas/io/stata.py b/pandas/io/stata.py index f97af0ed73cb0..034aa9a781962 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -77,8 +77,7 @@ "Version of given Stata file is {version}. pandas supports importing " "versions 105, 108, 111 (Stata 7SE), 113 (Stata 8/9), " "114 (Stata 10/11), 115 (Stata 12), 117 (Stata 13), 118 (Stata 14/15/16)," - "and 119 (Stata 15/16, over 32,767 variables)." 
-) + "and 119 (Stata 15/16, over 32,767 variables).") _statafile_processing_params1 = """\ convert_dates : bool, default True @@ -210,10 +209,8 @@ {_reader_notes} """ - _date_formats = ["%tc", "%tC", "%td", "%d", "%tw", "%tm", "%tq", "%th", "%ty"] - stata_epoch = datetime.datetime(1960, 1, 1) @@ -284,8 +281,8 @@ def convert_year_month_safe(year, month) -> Series: else: index = getattr(year, "index", None) return Series( - [datetime.datetime(y, m, 1) for y, m in zip(year, month)], index=index - ) + [datetime.datetime(y, m, 1) for y, m in zip(year, month)], + index=index) def convert_year_days_safe(year, days) -> Series: """ @@ -293,7 +290,8 @@ def convert_year_days_safe(year, days) -> Series: datetime or datetime64 Series """ if year.max() < (MAX_YEAR - 1) and year.min() > MIN_YEAR: - return to_datetime(year, format="%Y") + to_timedelta(days, unit="d") + return to_datetime(year, format="%Y") + to_timedelta(days, + unit="d") else: index = getattr(year, "index", None) value = [ @@ -316,7 +314,8 @@ def convert_delta_safe(base, deltas, unit) -> Series: elif unit == "ms": if deltas.max() > MAX_MS_DELTA or deltas.min() < MIN_MS_DELTA: values = [ - base + relativedelta(microseconds=(int(d) * 1000)) for d in deltas + base + relativedelta(microseconds=(int(d) * 1000)) + for d in deltas ] return Series(values, index=index) else: @@ -341,7 +340,8 @@ def convert_delta_safe(base, deltas, unit) -> Series: conv_dates = convert_delta_safe(base, ms, "ms") elif fmt.startswith(("%tC", "tC")): - warnings.warn("Encountered %tC format. Leaving in Stata Internal Format.") + warnings.warn( + "Encountered %tC format. Leaving in Stata Internal Format.") conv_dates = Series(dates, dtype=object) if has_bad_values: conv_dates[bad_locs] = NaT @@ -403,15 +403,15 @@ def parse_dates_safe(dates, delta=False, year=False, days=False): if is_datetime64_dtype(dates.dtype): if delta: time_delta = dates - stata_epoch - d["delta"] = time_delta._values.view(np.int64) // 1000 # microseconds + d["delta"] = time_delta._values.view( + np.int64) // 1000 # microseconds if days or year: date_index = DatetimeIndex(dates) d["year"] = date_index._data.year d["month"] = date_index._data.month if days: days_in_ns = dates.view(np.int64) - to_datetime( - d["year"], format="%Y" - ).view(np.int64) + d["year"], format="%Y").view(np.int64) d["days"] = days_in_ns // NS_PER_DAY elif infer_dtype(dates, skipna=False) == "datetime": @@ -435,10 +435,8 @@ def g(x: datetime.datetime) -> int: v = np.vectorize(g) d["days"] = v(dates) else: - raise ValueError( - "Columns containing dates must contain either " - "datetime64, datetime.datetime or null values." 
- ) + raise ValueError("Columns containing dates must contain either " + "datetime64, datetime.datetime or null values.") return DataFrame(d, index=index) @@ -471,7 +469,8 @@ def g(x: datetime.datetime) -> int: conv_dates = 4 * (d.year - stata_epoch.year) + (d.month - 1) // 3 elif fmt in ["%th", "th"]: d = parse_dates_safe(dates, year=True) - conv_dates = 2 * (d.year - stata_epoch.year) + (d.month > 6).astype(int) + conv_dates = 2 * (d.year - stata_epoch.year) + (d.month > + 6).astype(int) elif fmt in ["%ty", "ty"]: d = parse_dates_safe(dates, year=True) conv_dates = d.year @@ -585,7 +584,8 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame: for col in data: # Cast from unsupported types to supported types - is_nullable_int = isinstance(data[col].dtype, (IntegerDtype, BooleanDtype)) + is_nullable_int = isinstance(data[col].dtype, + (IntegerDtype, BooleanDtype)) orig = data[col] # We need to find orig_missing before altering data below orig_missing = orig.isna() @@ -601,7 +601,8 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame: for c_data in conversion_data: if dtype == c_data[0]: # Value of type variable "_IntType" of "iinfo" cannot be "object" - if data[col].max() <= np.iinfo(c_data[1]).max: # type: ignore[type-var] + if data[col].max() <= np.iinfo( + c_data[1]).max: # type: ignore[type-var] dtype = c_data[1] else: dtype = c_data[2] @@ -619,7 +620,8 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame: if data[col].max() > 32740 or data[col].min() < -32767: data[col] = data[col].astype(np.int32) elif dtype == np.int64: - if data[col].max() <= 2147483620 and data[col].min() >= -2147483647: + if data[col].max() <= 2147483620 and data[col].min( + ) >= -2147483647: data[col] = data[col].astype(np.int32) else: data[col] = data[col].astype(np.float64) @@ -629,8 +631,7 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame: if np.isinf(data[col]).any(): raise ValueError( f"Column {col} contains infinity or -infinity" - "which is outside the range supported by Stata." - ) + "which is outside the range supported by Stata.") value = data[col].max() if dtype == np.float32 and value > float32_max: data[col] = data[col].astype(np.float64) @@ -638,12 +639,12 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame: if value > float64_max: raise ValueError( f"Column {col} has a maximum value ({value}) outside the range " - f"supported by Stata ({float64_max})" - ) + f"supported by Stata ({float64_max})") if is_nullable_int: if orig_missing.any(): # Replace missing by Stata sentinel value - sentinel = StataMissingValue.BASE_MISSING_VALUES[data[col].dtype.name] + sentinel = StataMissingValue.BASE_MISSING_VALUES[ + data[col].dtype.name] data.loc[orig_missing, col] = sentinel if ws: warnings.warn(ws, PossiblePrecisionLoss) @@ -671,8 +672,7 @@ def __init__(self, catarray: Series, encoding: str = "latin-1"): self._encoding = encoding categories = catarray.cat.categories self.value_labels: list[tuple[int | float, str]] = list( - zip(np.arange(len(categories)), categories) - ) + zip(np.arange(len(categories)), categories)) self.value_labels.sort(key=lambda x: x[0]) self._prepare_value_labels() @@ -710,8 +710,7 @@ def _prepare_value_labels(self): if self.text_len > 32000: raise ValueError( "Stata value labels for a single variable must " - "have a combined length less than 32,000 characters." 
- ) + "have a combined length less than 32,000 characters.") # Ensure int32 self.off = np.array(offsets, dtype=np.int32) @@ -799,9 +798,9 @@ def __init__( self.labname = labname self._encoding = encoding - self.value_labels: list[tuple[int | float, str]] = sorted( - value_labels.items(), key=lambda x: x[0] - ) + self.value_labels: list[tuple[int | float, + str]] = sorted(value_labels.items(), + key=lambda x: x[0]) self._prepare_value_labels() @@ -916,11 +915,8 @@ def __repr__(self) -> str: return f"{type(self)}({self})" def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, type(self)) - and self.string == other.string - and self.value == other.value - ) + return (isinstance(other, type(self)) and self.string == other.string + and self.value == other.value) @classmethod def get_base_missing_value(cls, dtype: np.dtype) -> int | float: @@ -940,6 +936,7 @@ def get_base_missing_value(cls, dtype: np.dtype) -> int | float: class StataParser: + def __init__(self): # type code. @@ -958,15 +955,15 @@ def __init__(self): # with a label, but the underlying variable is -127 to 100 # we're going to drop the label and cast to int self.DTYPE_MAP = dict( - list(zip(range(1, 245), [np.dtype("a" + str(i)) for i in range(1, 245)])) - + [ - (251, np.dtype(np.int8)), - (252, np.dtype(np.int16)), - (253, np.dtype(np.int32)), - (254, np.dtype(np.float32)), - (255, np.dtype(np.float64)), - ] - ) + list( + zip(range(1, 245), + [np.dtype("a" + str(i)) for i in range(1, 245)])) + [ + (251, np.dtype(np.int8)), + (252, np.dtype(np.int16)), + (253, np.dtype(np.int32)), + (254, np.dtype(np.float32)), + (255, np.dtype(np.float64)), + ]) self.DTYPE_MAP_XML = { 32768: np.dtype(np.uint8), # Keys to GSO 65526: np.dtype(np.float64), @@ -977,7 +974,8 @@ def __init__(self): } # error: Argument 1 to "list" has incompatible type "str"; # expected "Iterable[int]" [arg-type] - self.TYPE_MAP = list(range(251)) + list("bhlfd") # type: ignore[arg-type] + self.TYPE_MAP = list(range(251)) + list( + "bhlfd") # type: ignore[arg-type] self.TYPE_MAP_XML = { # Not really a Q, unclear how to handle byteswap 32768: "Q", @@ -1019,13 +1017,17 @@ def __init__(self): # These missing values are the generic '.' in Stata, and are used # to replace nans self.MISSING_VALUES = { - "b": 101, - "h": 32741, - "l": 2147483621, - "f": np.float32(struct.unpack(" None: else: self._read_old_header(first_char) - self.has_string_data = len([x for x in self.typlist if type(x) is int]) > 0 + self.has_string_data = len([x for x in self.typlist if type(x) is int + ]) > 0 # calculate size of a data record self.col_sizes = [self._calcsize(typ) for typ in self.typlist] @@ -1200,16 +1203,16 @@ def _read_new_header(self) -> None: self.path_or_buf.read(27) # stata_dta>
self.format_version = int(self.path_or_buf.read(3)) if self.format_version not in [117, 118, 119]: - raise ValueError(_version_error.format(version=self.format_version)) + raise ValueError( + _version_error.format(version=self.format_version)) self._set_encoding() self.path_or_buf.read(21) # self.byteorder = self.path_or_buf.read(3) == b"MSF" and ">" or "<" self.path_or_buf.read(15) # nvar_type = "H" if self.format_version <= 118 else "I" nvar_size = 2 if self.format_version <= 118 else 4 - self.nvar = struct.unpack( - self.byteorder + nvar_type, self.path_or_buf.read(nvar_size) - )[0] + self.nvar = struct.unpack(self.byteorder + nvar_type, + self.path_or_buf.read(nvar_size))[0] self.path_or_buf.read(7) # self.nobs = self._get_nobs() @@ -1222,34 +1225,34 @@ def _read_new_header(self) -> None: self.path_or_buf.read(8) # position of self._seek_vartypes = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 16 - ) + struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + + 16) self._seek_varnames = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10 - ) + struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + + 10) self._seek_sortlist = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10 - ) + struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + + 10) self._seek_formats = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 9 - ) + struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + + 9) self._seek_value_label_names = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 19 - ) + struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + + 19) # Requires version-specific treatment self._seek_variable_labels = self._get_seek_variable_labels() self.path_or_buf.read(8) # self.data_location = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 6 - ) + struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + + 6) self.seek_strls = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 7 - ) + struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + + 7) self.seek_value_labels = ( - struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 14 - ) + struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + + 14) self.typlist, self.dtyplist = self._get_dtypes(self._seek_vartypes) @@ -1273,7 +1276,7 @@ def _read_new_header(self) -> None: # Get data type information, works for versions 117-119. 
def _get_dtypes( - self, seek_vartypes: int + self, seek_vartypes: int ) -> tuple[list[int | str], list[str | np.dtype]]: self.path_or_buf.seek(seek_vartypes) @@ -1288,7 +1291,8 @@ def f(typ: int) -> int | str: try: return self.TYPE_MAP_XML[typ] except KeyError as err: - raise ValueError(f"cannot convert stata types [{typ}]") from err + raise ValueError( + f"cannot convert stata types [{typ}]") from err typlist = [f(x) for x in raw_typlist] @@ -1300,7 +1304,8 @@ def g(typ: int) -> str | np.dtype: # "Union[str, dtype]") return self.DTYPE_MAP_XML[typ] # type: ignore[return-value] except KeyError as err: - raise ValueError(f"cannot convert stata dtype [{typ}]") from err + raise ValueError( + f"cannot convert stata dtype [{typ}]") from err dtyplist = [g(x) for x in raw_typlist] @@ -1309,7 +1314,9 @@ def g(typ: int) -> str | np.dtype: def _get_varlist(self) -> list[str]: # 33 in order formats, 129 in formats 118 and 119 b = 33 if self.format_version < 118 else 129 - return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)] + return [ + self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar) + ] # Returns the format list def _get_fmtlist(self) -> list[str]: @@ -1322,7 +1329,9 @@ def _get_fmtlist(self) -> list[str]: else: b = 7 - return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)] + return [ + self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar) + ] # Returns the label list def _get_lbllist(self) -> list[str]: @@ -1332,32 +1341,40 @@ def _get_lbllist(self) -> list[str]: b = 33 else: b = 9 - return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)] + return [ + self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar) + ] def _get_variable_labels(self) -> list[str]: if self.format_version >= 118: vlblist = [ - self._decode(self.path_or_buf.read(321)) for _ in range(self.nvar) + self._decode(self.path_or_buf.read(321)) + for _ in range(self.nvar) ] elif self.format_version > 105: vlblist = [ - self._decode(self.path_or_buf.read(81)) for _ in range(self.nvar) + self._decode(self.path_or_buf.read(81)) + for _ in range(self.nvar) ] else: vlblist = [ - self._decode(self.path_or_buf.read(32)) for _ in range(self.nvar) + self._decode(self.path_or_buf.read(32)) + for _ in range(self.nvar) ] return vlblist def _get_nobs(self) -> int: if self.format_version >= 118: - return struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0] + return struct.unpack(self.byteorder + "Q", + self.path_or_buf.read(8))[0] else: - return struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0] + return struct.unpack(self.byteorder + "I", + self.path_or_buf.read(4))[0] def _get_data_label(self) -> str: if self.format_version >= 118: - strlen = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0] + strlen = struct.unpack(self.byteorder + "H", + self.path_or_buf.read(2))[0] return self._decode(self.path_or_buf.read(strlen)) elif self.format_version == 117: strlen = struct.unpack("b", self.path_or_buf.read(1))[0] @@ -1387,22 +1404,25 @@ def _get_seek_variable_labels(self) -> int: # variable, 20 for the closing tag and 17 for the opening tag return self._seek_value_label_names + (33 * self.nvar) + 20 + 17 elif self.format_version >= 118: - return struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 17 + return struct.unpack(self.byteorder + "q", + self.path_or_buf.read(8))[0] + 17 else: raise ValueError() def _read_old_header(self, first_char: bytes) -> None: self.format_version = struct.unpack("b", 
first_char)[0] if self.format_version not in [104, 105, 108, 111, 113, 114, 115]: - raise ValueError(_version_error.format(version=self.format_version)) + raise ValueError( + _version_error.format(version=self.format_version)) self._set_encoding() self.byteorder = ( - struct.unpack("b", self.path_or_buf.read(1))[0] == 0x1 and ">" or "<" - ) + struct.unpack("b", self.path_or_buf.read(1))[0] == 0x1 and ">" + or "<") self.filetype = struct.unpack("b", self.path_or_buf.read(1))[0] self.path_or_buf.read(1) # unused - self.nvar = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0] + self.nvar = struct.unpack(self.byteorder + "H", + self.path_or_buf.read(2))[0] self.nobs = self._get_nobs() self._data_label = self._get_data_label() @@ -1426,20 +1446,24 @@ def _read_old_header(self, first_char: bytes) -> None: self.typlist = [self.TYPE_MAP[typ] for typ in typlist] except ValueError as err: invalid_types = ",".join([str(x) for x in typlist]) - raise ValueError(f"cannot convert stata types [{invalid_types}]") from err + raise ValueError( + f"cannot convert stata types [{invalid_types}]") from err try: self.dtyplist = [self.DTYPE_MAP[typ] for typ in typlist] except ValueError as err: invalid_dtypes = ",".join([str(x) for x in typlist]) - raise ValueError(f"cannot convert stata dtypes [{invalid_dtypes}]") from err + raise ValueError( + f"cannot convert stata dtypes [{invalid_dtypes}]") from err if self.format_version > 108: self.varlist = [ - self._decode(self.path_or_buf.read(33)) for _ in range(self.nvar) + self._decode(self.path_or_buf.read(33)) + for _ in range(self.nvar) ] else: self.varlist = [ - self._decode(self.path_or_buf.read(9)) for _ in range(self.nvar) + self._decode(self.path_or_buf.read(9)) + for _ in range(self.nvar) ] self.srtlist = struct.unpack( self.byteorder + ("h" * (self.nvar + 1)), @@ -1459,17 +1483,14 @@ def _read_old_header(self, first_char: bytes) -> None: if self.format_version > 104: while True: - data_type = struct.unpack( - self.byteorder + "b", self.path_or_buf.read(1) - )[0] + data_type = struct.unpack(self.byteorder + "b", + self.path_or_buf.read(1))[0] if self.format_version > 108: - data_len = struct.unpack( - self.byteorder + "i", self.path_or_buf.read(4) - )[0] + data_len = struct.unpack(self.byteorder + "i", + self.path_or_buf.read(4))[0] else: - data_len = struct.unpack( - self.byteorder + "h", self.path_or_buf.read(2) - )[0] + data_len = struct.unpack(self.byteorder + "h", + self.path_or_buf.read(2))[0] if data_type == 0: break self.path_or_buf.read(data_len) @@ -1486,7 +1507,8 @@ def _setup_dtype(self) -> np.dtype: for i, typ in enumerate(self.typlist): if typ in self.NUMPY_TYPE_MAP: typ = cast(str, typ) # only strs in NUMPY_TYPE_MAP - dtypes.append(("s" + str(i), self.byteorder + self.NUMPY_TYPE_MAP[typ])) + dtypes.append( + ("s" + str(i), self.byteorder + self.NUMPY_TYPE_MAP[typ])) else: dtypes.append(("s" + str(i), "S" + str(typ))) self._dtype = np.dtype(dtypes) @@ -1549,14 +1571,16 @@ def _read_value_labels(self) -> None: labname = self._decode(self.path_or_buf.read(129)) self.path_or_buf.read(3) # padding - n = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0] - txtlen = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0] - off = np.frombuffer( - self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n - ) - val = np.frombuffer( - self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n - ) + n = struct.unpack(self.byteorder + "I", + self.path_or_buf.read(4))[0] + txtlen = 
struct.unpack(self.byteorder + "I", + self.path_or_buf.read(4))[0] + off = np.frombuffer(self.path_or_buf.read(4 * n), + dtype=self.byteorder + "i4", + count=n) + val = np.frombuffer(self.path_or_buf.read(4 * n), + dtype=self.byteorder + "i4", + count=n) ii = np.argsort(off) off = off[ii] val = val[ii] @@ -1564,7 +1588,8 @@ def _read_value_labels(self) -> None: self.value_label_dict[labname] = {} for i in range(n): end = off[i + 1] if i < n - 1 else txtlen - self.value_label_dict[labname][val[i]] = self._decode(txt[off[i] : end]) + self.value_label_dict[labname][val[i]] = self._decode( + txt[off[i]:end]) if self.format_version >= 117: self.path_or_buf.read(6) # self._value_labels_read = True @@ -1578,19 +1603,21 @@ def _read_strls(self) -> None: break if self.format_version == 117: - v_o = struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0] + v_o = struct.unpack(self.byteorder + "Q", + self.path_or_buf.read(8))[0] else: buf = self.path_or_buf.read(12) # Only tested on little endian file on little endian machine. v_size = 2 if self.format_version == 118 else 3 if self.byteorder == "<": - buf = buf[0:v_size] + buf[4 : (12 - v_size)] + buf = buf[0:v_size] + buf[4:(12 - v_size)] else: # This path may not be correct, impossible to test - buf = buf[0:v_size] + buf[(4 + v_size) :] + buf = buf[0:v_size] + buf[(4 + v_size):] v_o = struct.unpack("Q", buf)[0] typ = struct.unpack("B", self.path_or_buf.read(1))[0] - length = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0] + length = struct.unpack(self.byteorder + "I", + self.path_or_buf.read(4))[0] va = self.path_or_buf.read(length) if typ == 130: decoded_va = va[0:-1].decode(self._encoding) @@ -1681,9 +1708,9 @@ def read( offset = self._lines_read * dtype.itemsize self.path_or_buf.seek(self.data_location + offset) read_lines = min(nrows, self.nobs - self._lines_read) - raw_data = np.frombuffer( - self.path_or_buf.read(read_len), dtype=dtype, count=read_lines - ) + raw_data = np.frombuffer(self.path_or_buf.read(read_len), + dtype=dtype, + count=read_lines) self._lines_read += read_lines if self._lines_read == self.nobs: @@ -1706,7 +1733,8 @@ def read( # restarting at 0 for each chunk. 
if index_col is None: rng = np.arange(self._lines_read - read_lines, self._lines_read) - data.index = Index(rng) # set attr instead of set_index to avoid copy + data.index = Index( + rng) # set attr instead of set_index to avoid copy if columns is not None: try: @@ -1734,8 +1762,7 @@ def read( if dtype != np.dtype(object) and dtype != self.dtyplist[i]: requires_type_conversion = True data_formatted.append( - (col, Series(data[col], ix, self.dtyplist[i])) - ) + (col, Series(data[col], ix, self.dtyplist[i]))) else: data_formatted.append((col, data[col])) if requires_type_conversion: @@ -1754,16 +1781,15 @@ def any_startswith(x: str) -> bool: col = data.columns[i] try: data[col] = _stata_elapsed_date_to_datetime_vec( - data[col], self.fmtlist[i] - ) + data[col], self.fmtlist[i]) except ValueError: self.close() raise if convert_categoricals and self.format_version > 108: - data = self._do_convert_categoricals( - data, self.value_label_dict, self.lbllist, order_categoricals - ) + data = self._do_convert_categoricals(data, self.value_label_dict, + self.lbllist, + order_categoricals) if not preserve_dtypes: retyped_data = [] @@ -1774,9 +1800,9 @@ def any_startswith(x: str) -> bool: dtype = np.dtype(np.float64) convert = True elif dtype in ( - np.dtype(np.int8), - np.dtype(np.int16), - np.dtype(np.int32), + np.dtype(np.int8), + np.dtype(np.int16), + np.dtype(np.int32), ): dtype = np.dtype(np.int64) convert = True @@ -1789,7 +1815,8 @@ def any_startswith(x: str) -> bool: return data - def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFrame: + def _do_convert_missing(self, data: DataFrame, + convert_missing: bool) -> DataFrame: # Check for missing values, and replace if found replacements = {} for i, colname in enumerate(data): @@ -1810,7 +1837,8 @@ def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFra if convert_missing: # Replacement follows Stata notation missing_loc = np.nonzero(np.asarray(missing))[0] - umissing, umissing_loc = np.unique(series[missing], return_inverse=True) + umissing, umissing_loc = np.unique(series[missing], + return_inverse=True) replacement = Series(series, dtype=object) for j, um in enumerate(umissing): missing_value = StataMissingValue(um) @@ -1846,7 +1874,8 @@ def _insert_strls(self, data: DataFrame) -> DataFrame: data.iloc[:, i] = [self.GSO[str(k)] for k in data.iloc[:, i]] return data - def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFrame: + def _do_select_columns(self, data: DataFrame, + columns: Sequence[str]) -> DataFrame: if not self._column_selector_set: column_set = set(columns) @@ -1855,10 +1884,8 @@ def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFra unmatched = column_set.difference(data.columns) if unmatched: joined = ", ".join(list(unmatched)) - raise ValueError( - "The following columns were not " - f"found in the Stata data set: {joined}" - ) + raise ValueError("The following columns were not " + f"found in the Stata data set: {joined}") # Copy information for retained columns for later processing dtyplist = [] typlist = [] @@ -1907,13 +1934,12 @@ def _do_convert_categoricals( else: if self._using_iterator: # warn is using an iterator - warnings.warn( - categorical_conversion_warning, CategoricalConversionWarning - ) + warnings.warn(categorical_conversion_warning, + CategoricalConversionWarning) initial_categories = None - cat_data = Categorical( - column, categories=initial_categories, ordered=order_categoricals - ) + cat_data = 
Categorical(column, + categories=initial_categories, + ordered=order_categoricals) if initial_categories is None: # If None here, then we need to match the cats in the Categorical categories = [] @@ -2046,27 +2072,28 @@ def _convert_datetime_to_stata_type(fmt: str) -> np.dtype: Convert from one of the stata date formats to a type in TYPE_MAP. """ if fmt in [ - "tc", - "%tc", - "td", - "%td", - "tw", - "%tw", - "tm", - "%tm", - "tq", - "%tq", - "th", - "%th", - "ty", - "%ty", + "tc", + "%tc", + "td", + "%td", + "tw", + "%tw", + "tm", + "%tm", + "tq", + "%tq", + "th", + "%th", + "ty", + "%ty", ]: return np.dtype(np.float64) # Stata expects doubles for SIFs else: raise NotImplementedError(f"Format {fmt} not implemented") -def _maybe_convert_to_int_keys(convert_dates: dict, varlist: list[Hashable]) -> dict: +def _maybe_convert_to_int_keys(convert_dates: dict, + varlist: list[Hashable]) -> dict: new_dict = {} for key in convert_dates: if not convert_dates[key].startswith("%"): # make sure proper fmts @@ -2075,7 +2102,8 @@ def _maybe_convert_to_int_keys(convert_dates: dict, varlist: list[Hashable]) -> new_dict.update({varlist.index(key): convert_dates[key]}) else: if not isinstance(key, int): - raise ValueError("convert_dates key must be a column or an integer") + raise ValueError( + "convert_dates key must be a column or an integer") new_dict.update({key: convert_dates[key]}) return new_dict @@ -2116,9 +2144,10 @@ def _dtype_to_stata_type(dtype: np.dtype, column: Series) -> int: raise NotImplementedError(f"Data type {dtype} not supported.") -def _dtype_to_default_stata_fmt( - dtype, column: Series, dta_version: int = 114, force_strl: bool = False -) -> str: +def _dtype_to_default_stata_fmt(dtype, + column: Series, + dta_version: int = 114, + force_strl: bool = False) -> str: """ Map numpy dtype to stata's default format for this type. Not terribly important since users can change this in Stata. Semantics are @@ -2147,7 +2176,8 @@ def _dtype_to_default_stata_fmt( if dta_version >= 117: return "%9s" else: - raise ValueError(excessive_string_length_error.format(column.name)) + raise ValueError( + excessive_string_length_error.format(column.name)) return "%" + str(max(itemsize, 1)) + "s" elif dtype == np.float64: return "%10.0g" @@ -2302,8 +2332,7 @@ def _write_bytes(self, value: bytes) -> None: self.handles.handle.write(value) def _prepare_non_cat_value_labels( - self, data: DataFrame - ) -> list[StataNonCatValueLabel]: + self, data: DataFrame) -> list[StataNonCatValueLabel]: """ Check for value labels provided for non-categorical columns. Value labels @@ -2320,16 +2349,14 @@ def _prepare_non_cat_value_labels( else: raise KeyError( f"Can't create value labels for {labname}, it wasn't " - "found in the dataset." - ) + "found in the dataset.") if not is_numeric_dtype(data[colname].dtype): # Labels should not be passed explicitly for categorical # columns that will be converted to int raise ValueError( f"Can't create value labels for {labname}, value labels " - "can only be applied to numeric columns." - ) + "can only be applied to numeric columns.") svl = StataNonCatValueLabel(colname, labels) non_cat_value_labels.append(svl) return non_cat_value_labels @@ -2353,10 +2380,8 @@ def _prepare_categoricals(self, data: DataFrame) -> DataFrame: self._value_labels.append(svl) dtype = data[col].cat.codes.dtype if dtype == np.int64: - raise ValueError( - "It is not possible to export " - "int64-based categorical data to Stata." 
- ) + raise ValueError("It is not possible to export " + "int64-based categorical data to Stata.") values = data[col].cat.codes._values.copy() # Upcast if needed so that correct missing values can be set @@ -2418,12 +2443,8 @@ def _validate_variable_name(self, name: str) -> str: and _. """ for c in name: - if ( - (c < "A" or c > "Z") - and (c < "a" or c > "z") - and (c < "0" or c > "9") - and c != "_" - ): + if ((c < "A" or c > "Z") and (c < "a" or c > "z") + and (c < "0" or c > "9") and c != "_"): name = name.replace(c, "_") return name @@ -2460,14 +2481,14 @@ def _check_column_names(self, data: DataFrame) -> DataFrame: if "0" <= name[0] <= "9": name = "_" + name - name = name[: min(len(name), 32)] + name = name[:min(len(name), 32)] if not name == orig_name: # check for duplicates while columns.count(name) > 0: # prepend ascending number to avoid duplicates name = "_" + str(duplicate_var_id) + name - name = name[: min(len(name), 32)] + name = name[:min(len(name), 32)] duplicate_var_id += 1 converted_names[orig_name] = name @@ -2500,7 +2521,8 @@ def _set_formats_and_types(self, dtypes: Series) -> None: self.fmtlist: list[str] = [] self.typlist: list[int] = [] for col, dtype in dtypes.items(): - self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, self.data[col])) + self.fmtlist.append( + _dtype_to_default_stata_fmt(dtype, self.data[col])) self.typlist.append(_dtype_to_stata_type(dtype, self.data[col])) def _prepare_pandas(self, data: DataFrame) -> None: @@ -2552,10 +2574,10 @@ def _prepare_pandas(self, data: DataFrame) -> None: self._convert_dates[col] = "tc" self._convert_dates = _maybe_convert_to_int_keys( - self._convert_dates, self.varlist - ) + self._convert_dates, self.varlist) for key in self._convert_dates: - new_type = _convert_datetime_to_stata_type(self._convert_dates[key]) + new_type = _convert_datetime_to_stata_type( + self._convert_dates[key]) dtypes[key] = np.dtype(new_type) # Verify object arrays are strings and encode to bytes @@ -2590,20 +2612,16 @@ def _encode_strings(self) -> None: inferred_dtype = infer_dtype(column, skipna=True) if not ((inferred_dtype == "string") or len(column) == 0): col = column.name - raise ValueError( - f"""\ + raise ValueError(f"""\ Column `{col}` cannot be exported.\n\nOnly string-like object arrays containing all strings or a mix of strings and None can be exported. Object arrays containing only null values are prohibited. Other object types cannot be exported and must first be converted to one of the -supported types.""" - ) +supported types.""") encoded = self.data[col].str.encode(self._encoding) # If larger than _max_string_length do nothing - if ( - max_len_string_array(ensure_object(encoded._values)) - <= self._max_string_length - ): + if (max_len_string_array(ensure_object(encoded._values)) <= + self._max_string_length): self.data[col] = encoded def write_file(self) -> None: @@ -2611,23 +2629,23 @@ def write_file(self) -> None: Export DataFrame object to Stata dta format. """ with get_handle( - self._fname, - "wb", - compression=self._compression, - is_text=False, - storage_options=self.storage_options, + self._fname, + "wb", + compression=self._compression, + is_text=False, + storage_options=self.storage_options, ) as self.handles: if self.handles.compression["method"] is not None: # ZipFile creates a file (with the same name) for each write call. # Write it first into a buffer and then write the buffer to the ZipFile. 
- self._output_file, self.handles.handle = self.handles.handle, BytesIO() + self._output_file, self.handles.handle = self.handles.handle, BytesIO( + ) self.handles.created_handles.append(self.handles.handle) try: - self._write_header( - data_label=self._data_label, time_stamp=self._time_stamp - ) + self._write_header(data_label=self._data_label, + time_stamp=self._time_stamp) self._write_map() self._write_variable_types() self._write_varnames() @@ -2646,9 +2664,9 @@ def write_file(self) -> None: self._close() except Exception as exc: self.handles.close() - if isinstance(self._fname, (str, os.PathLike)) and os.path.isfile( - self._fname - ): + if isinstance(self._fname, + (str, os.PathLike)) and os.path.isfile( + self._fname): try: os.unlink(self._fname) except OSError: @@ -2719,8 +2737,7 @@ def _write_header( self._write_bytes(self._null_terminate_bytes(_pad_bytes("", 80))) else: self._write_bytes( - self._null_terminate_bytes(_pad_bytes(data_label[:80], 80)) - ) + self._null_terminate_bytes(_pad_bytes(data_label[:80], 80))) # time stamp, 18 bytes, char, null terminated # format dd Mon yyyy hh:mm if time_stamp is None: @@ -2744,11 +2761,8 @@ def _write_header( "Dec", ] month_lookup = {i + 1: month for i, month in enumerate(months)} - ts = ( - time_stamp.strftime("%d ") - + month_lookup[time_stamp.month] - + time_stamp.strftime(" %Y %H:%M") - ) + ts = (time_stamp.strftime("%d ") + month_lookup[time_stamp.month] + + time_stamp.strftime(" %Y %H:%M")) self._write_bytes(self._null_terminate_bytes(ts)) def _write_variable_types(self) -> None: @@ -2798,13 +2812,13 @@ def _write_variable_labels(self) -> None: if col in self._variable_labels: label = self._variable_labels[col] if len(label) > 80: - raise ValueError("Variable labels must be 80 characters or fewer") + raise ValueError( + "Variable labels must be 80 characters or fewer") is_latin1 = all(ord(c) < 256 for c in label) if not is_latin1: raise ValueError( "Variable labels must contain only characters that " - "can be encoded in Latin-1" - ) + "can be encoded in Latin-1") self._write(_pad_bytes(label, 81)) else: self._write(blank) @@ -2822,8 +2836,7 @@ def _prepare_data(self) -> np.recarray: for i, col in enumerate(data): if i in convert_dates: data[col] = _datetime_to_stata_elapsed_vec( - data[col], self.fmtlist[i] - ) + data[col], self.fmtlist[i]) # 2. Convert strls data = self._convert_strls(data) @@ -2833,7 +2846,8 @@ def _prepare_data(self) -> np.recarray: for i, col in enumerate(data): typ = typlist[i] if typ <= self._max_string_length: - data[col] = data[col].fillna("").apply(_pad_bytes, args=(typ,)) + data[col] = data[col].fillna("").apply(_pad_bytes, + args=(typ, )) stype = f"S{typ}" dtypes[col] = stype data[col] = data[col].astype(stype) @@ -2857,7 +2871,8 @@ def _null_terminate_bytes(self, s: str) -> bytes: return self._null_terminate_str(s).encode(self._encoding) -def _dtype_to_stata_type_117(dtype: np.dtype, column: Series, force_strl: bool) -> int: +def _dtype_to_stata_type_117(dtype: np.dtype, column: Series, + force_strl: bool) -> int: """ Converts dtype types to stata types. Returns the byte of the given ordinal. See TYPE_MAP and comments for an explanation. 
This is also explained in @@ -2966,7 +2981,7 @@ def __init__( o_size = 6 else: # version == 119 o_size = 5 - self._o_offet = 2 ** (8 * (8 - o_size)) + self._o_offet = 2**(8 * (8 - o_size)) self._gso_o_type = gso_o_type self._gso_v_type = gso_v_type @@ -3222,7 +3237,8 @@ def _tag(val: str | bytes, tag: str) -> bytes: """Surround val with """ if isinstance(val, str): val = bytes(val, "utf-8") - return bytes("<" + tag + ">", "utf-8") + val + bytes("", "utf-8") + return bytes("<" + tag + ">", "utf-8") + val + bytes( + "", "utf-8") def _update_map(self, tag: str) -> None: """Update map location for tag with file position""" @@ -3244,10 +3260,12 @@ def _write_header( bio.write(self._tag(byteorder == ">" and "MSF" or "LSF", "byteorder")) # number of vars, 2 bytes in 117 and 118, 4 byte in 119 nvar_type = "H" if self._dta_version <= 118 else "I" - bio.write(self._tag(struct.pack(byteorder + nvar_type, self.nvar), "K")) + bio.write(self._tag(struct.pack(byteorder + nvar_type, self.nvar), + "K")) # 117 uses 4 bytes, 118 uses 8 nobs_size = "I" if self._dta_version == 117 else "Q" - bio.write(self._tag(struct.pack(byteorder + nobs_size, self.nobs), "N")) + bio.write(self._tag(struct.pack(byteorder + nobs_size, self.nobs), + "N")) # data label 81 bytes, char, null terminated label = data_label[:80] if data_label is not None else "" encoded_label = label.encode(self._encoding) @@ -3277,11 +3295,8 @@ def _write_header( "Dec", ] month_lookup = {i + 1: month for i, month in enumerate(months)} - ts = ( - time_stamp.strftime("%d ") - + month_lookup[time_stamp.month] - + time_stamp.strftime(" %Y %H:%M") - ) + ts = (time_stamp.strftime("%d ") + month_lookup[time_stamp.month] + + time_stamp.strftime(" %Y %H:%M")) # '\x11' added due to inspection of Stata file stata_ts = b"\x11" + bytes(ts, "utf-8") bio.write(self._tag(stata_ts, "timestamp")) @@ -3338,7 +3353,8 @@ def _write_varnames(self) -> None: def _write_sortlist(self) -> None: self._update_map("sortlist") sort_size = 2 if self._dta_version < 119 else 4 - self._write_bytes(self._tag(b"\x00" * sort_size * (self.nvar + 1), "sortlist")) + self._write_bytes( + self._tag(b"\x00" * sort_size * (self.nvar + 1), "sortlist")) def _write_formats(self) -> None: self._update_map("formats") @@ -3359,7 +3375,8 @@ def _write_value_label_names(self) -> None: if self._has_value_labels[i]: name = self.varlist[i] name = self._null_terminate_str(name) - encoded_name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1) + encoded_name = _pad_bytes_new(name[:32].encode(self._encoding), + vl_len + 1) bio.write(encoded_name) self._write_bytes(self._tag(bio.getvalue(), "value_label_names")) @@ -3381,14 +3398,14 @@ def _write_variable_labels(self) -> None: if col in self._variable_labels: label = self._variable_labels[col] if len(label) > 80: - raise ValueError("Variable labels must be 80 characters or fewer") + raise ValueError( + "Variable labels must be 80 characters or fewer") try: encoded = label.encode(self._encoding) except UnicodeEncodeError as err: raise ValueError( "Variable labels must contain only characters that " - f"can be encoded in {self._encoding}" - ) from err + f"can be encoded in {self._encoding}") from err bio.write(_pad_bytes_new(encoded, vl_len + 1)) else: @@ -3444,13 +3461,14 @@ def _convert_strls(self, data: DataFrame) -> DataFrame: convert_strl variable """ convert_cols = [ - col - for i, col in enumerate(data) + col for i, col in enumerate(data) if self.typlist[i] == 32768 or col in self._convert_strl ] if convert_cols: - ssw = 
StataStrLWriter(data, convert_cols, version=self._dta_version) + ssw = StataStrLWriter(data, + convert_cols, + version=self._dta_version) tab, new_data = ssw.generate_table() data = new_data self._strl_blob = ssw.generate_blob(tab) @@ -3469,8 +3487,7 @@ def _set_formats_and_types(self, dtypes: Series) -> None: ) self.fmtlist.append(fmt) self.typlist.append( - _dtype_to_stata_type_117(dtype, self.data[col], force_strl) - ) + _dtype_to_stata_type_117(dtype, self.data[col], force_strl)) class StataWriterUTF8(StataWriter117): @@ -3600,8 +3617,7 @@ def __init__( elif version == 118 and data.shape[1] > 32767: raise ValueError( "You must use version 119 for data sets containing more than" - "32,767 variables" - ) + "32,767 variables") super().__init__( fname, @@ -3642,13 +3658,9 @@ def _validate_variable_name(self, name: str) -> str: """ # High code points appear to be acceptable for c in name: - if ( - ord(c) < 128 - and (c < "A" or c > "Z") - and (c < "a" or c > "z") - and (c < "0" or c > "9") - and c != "_" - ) or 128 <= ord(c) < 256: + if (ord(c) < 128 and (c < "A" or c > "Z") and + (c < "a" or c > "z") and + (c < "0" or c > "9") and c != "_") or 128 <= ord(c) < 256: name = name.replace(c, "_") return name diff --git a/pandas/tests/apply/test_frame_apply.py b/pandas/tests/apply/test_frame_apply.py index 3726ca6682a46..fc3c3e8bb6829 100644 --- a/pandas/tests/apply/test_frame_apply.py +++ b/pandas/tests/apply/test_frame_apply.py @@ -56,7 +56,8 @@ def test_apply_axis1_with_ea(): @pytest.mark.parametrize( "data, dtype", - [(1, None), (1, CategoricalDtype([1])), (Timestamp("2013-01-01", tz="UTC"), None)], + [(1, None), (1, CategoricalDtype([1])), + (Timestamp("2013-01-01", tz="UTC"), None)], ) def test_agg_axis1_duplicate_index(data, dtype): # GH 42380 @@ -68,12 +69,10 @@ def test_agg_axis1_duplicate_index(data, dtype): def test_apply_mixed_datetimelike(): # mixed datetimelike # GH 7778 - expected = DataFrame( - { - "A": date_range("20130101", periods=3), - "B": pd.to_timedelta(np.arange(3), unit="s"), - } - ) + expected = DataFrame({ + "A": date_range("20130101", periods=3), + "B": pd.to_timedelta(np.arange(3), unit="s"), + }) result = expected.apply(lambda x: x, axis=1) tm.assert_frame_equal(result, expected) @@ -194,12 +193,12 @@ def test_apply_broadcast_lists_columns(float_frame): def test_apply_broadcast_lists_index(float_frame): - result = float_frame.apply( - lambda x: list(range(len(float_frame.index))), result_type="broadcast" - ) + result = float_frame.apply(lambda x: list(range(len(float_frame.index))), + result_type="broadcast") m = list(range(len(float_frame.index))) expected = DataFrame( - {c: m for c in float_frame.columns}, + {c: m + for c in float_frame.columns}, dtype="float64", index=float_frame.index, ) @@ -226,6 +225,7 @@ def test_apply_broadcast_series_lambda_func(int_frame_const_col): @pytest.mark.parametrize("axis", [0, 1]) def test_apply_raw_float_frame(float_frame, axis): + def _assert_raw(x): assert isinstance(x, np.ndarray) assert x.ndim == 1 @@ -249,6 +249,7 @@ def test_apply_raw_float_frame_no_reduction(float_frame): @pytest.mark.parametrize("axis", [0, 1]) def test_apply_raw_mixed_type_frame(mixed_type_frame, axis): + def _assert_raw(x): assert isinstance(x, np.ndarray) assert x.ndim == 1 @@ -285,9 +286,8 @@ def test_apply_mixed_dtype_corner_indexing(): @pytest.mark.parametrize("ax", ["index", "columns"]) -@pytest.mark.parametrize( - "func", [lambda x: x, lambda x: x.mean()], ids=["identity", "mean"] -) +@pytest.mark.parametrize("func", [lambda x: x, lambda x: 
x.mean()], + ids=["identity", "mean"]) @pytest.mark.parametrize("raw", [True, False]) @pytest.mark.parametrize("axis", [0, 1]) def test_apply_empty_infer_type(ax, func, raw, axis): @@ -315,6 +315,7 @@ def test_apply_empty_infer_type_broadcast(): def test_apply_with_args_kwds_add_some(float_frame): + def add_some(x, howmuch=0): return x + howmuch @@ -324,6 +325,7 @@ def add_some(x, howmuch=0): def test_apply_with_args_kwds_agg_and_add(float_frame): + def agg_and_add(x, howmuch=0): return x.mean() + howmuch @@ -333,10 +335,11 @@ def agg_and_add(x, howmuch=0): def test_apply_with_args_kwds_subtract_and_divide(float_frame): + def subtract_and_divide(x, sub, divide=1): return (x - sub) / divide - result = float_frame.apply(subtract_and_divide, args=(2,), divide=2) + result = float_frame.apply(subtract_and_divide, args=(2, ), divide=2) expected = float_frame.apply(lambda x: (x - 2.0) / 2.0) tm.assert_frame_equal(result, expected) @@ -355,14 +358,30 @@ def test_apply_reduce_Series(float_frame): def test_apply_reduce_to_dict(): # GH 25196 37544 - data = DataFrame([[1, 2], [3, 4]], columns=["c0", "c1"], index=["i0", "i1"]) + data = DataFrame([[1, 2], [3, 4]], + columns=["c0", "c1"], + index=["i0", "i1"]) result = data.apply(dict, axis=0) - expected = Series([{"i0": 1, "i1": 3}, {"i0": 2, "i1": 4}], index=data.columns) + expected = Series([{ + "i0": 1, + "i1": 3 + }, { + "i0": 2, + "i1": 4 + }], + index=data.columns) tm.assert_series_equal(result, expected) result = data.apply(dict, axis=1) - expected = Series([{"c0": 1, "c1": 2}, {"c0": 3, "c1": 4}], index=data.index) + expected = Series([{ + "c0": 1, + "c1": 2 + }, { + "c0": 3, + "c1": 4 + }], + index=data.index) tm.assert_series_equal(result, expected) @@ -370,11 +389,15 @@ def test_apply_differently_indexed(): df = DataFrame(np.random.randn(20, 10)) result = df.apply(Series.describe, axis=0) - expected = DataFrame({i: v.describe() for i, v in df.items()}, columns=df.columns) + expected = DataFrame({i: v.describe() + for i, v in df.items()}, + columns=df.columns) tm.assert_frame_equal(result, expected) result = df.apply(Series.describe, axis=1) - expected = DataFrame({i: v.describe() for i, v in df.T.items()}, columns=df.index).T + expected = DataFrame({i: v.describe() + for i, v in df.T.items()}, + columns=df.index).T tm.assert_frame_equal(result, expected) @@ -414,52 +437,53 @@ def f(r): def test_apply_convert_objects(): - expected = DataFrame( - { - "A": [ - "foo", - "foo", - "foo", - "foo", - "bar", - "bar", - "bar", - "bar", - "foo", - "foo", - "foo", - ], - "B": [ - "one", - "one", - "one", - "two", - "one", - "one", - "one", - "two", - "two", - "two", - "one", - ], - "C": [ - "dull", - "dull", - "shiny", - "dull", - "dull", - "shiny", - "shiny", - "dull", - "shiny", - "shiny", - "shiny", - ], - "D": np.random.randn(11), - "E": np.random.randn(11), - "F": np.random.randn(11), - } - ) + expected = DataFrame({ + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": [ + "dull", + "dull", + "shiny", + "dull", + "dull", + "shiny", + "shiny", + "dull", + "shiny", + "shiny", + "shiny", + ], + "D": + np.random.randn(11), + "E": + np.random.randn(11), + "F": + np.random.randn(11), + }) result = expected.apply(lambda x: x, axis=1)._convert(datetime=True) tm.assert_frame_equal(result, expected) @@ -491,17 +515,21 @@ def test_apply_attach_name_non_reduction(float_frame): def 
test_apply_attach_name_non_reduction_axis1(float_frame): result = float_frame.apply(lambda x: np.repeat(x.name, len(x)), axis=1) expected = Series( - np.repeat(t[0], len(float_frame.columns)) for t in float_frame.itertuples() - ) + np.repeat(t[0], len(float_frame.columns)) + for t in float_frame.itertuples()) expected.index = float_frame.index tm.assert_series_equal(result, expected) def test_apply_multi_index(): index = MultiIndex.from_arrays([["a", "a", "b"], ["c", "d", "d"]]) - s = DataFrame([[1, 2], [3, 4], [5, 6]], index=index, columns=["col1", "col2"]) + s = DataFrame([[1, 2], [3, 4], [5, 6]], + index=index, + columns=["col1", "col2"]) result = s.apply(lambda x: Series({"min": min(x), "max": max(x)}), 1) - expected = DataFrame([[1, 2], [3, 4], [5, 6]], index=index, columns=["min", "max"]) + expected = DataFrame([[1, 2], [3, 4], [5, 6]], + index=index, + columns=["min", "max"]) tm.assert_frame_equal(result, expected, check_like=True) @@ -510,9 +538,22 @@ def test_apply_multi_index(): [ [ DataFrame([["foo", "bar"], ["spam", "eggs"]]), - Series([{0: "foo", 1: "spam"}, {0: "bar", 1: "eggs"}]), + Series([{ + 0: "foo", + 1: "spam" + }, { + 0: "bar", + 1: "eggs" + }]), ], - [DataFrame([[0, 1], [2, 3]]), Series([{0: 0, 1: 2}, {0: 1, 1: 3}])], + [DataFrame([[0, 1], [2, 3]]), + Series([{ + 0: 0, + 1: 2 + }, { + 0: 1, + 1: 3 + }])], ], ) def test_apply_dict(df, dicts): @@ -560,7 +601,8 @@ def test_applymap_str(): @pytest.mark.parametrize( "col, val", - [["datetime", Timestamp("20130101")], ["timedelta", pd.Timedelta("1 min")]], + [["datetime", Timestamp("20130101")], ["timedelta", + pd.Timedelta("1 min")]], ) def test_applymap_datetimelike(col, val): # datetime/timedelta @@ -576,7 +618,11 @@ def test_applymap_datetimelike(col, val): DataFrame(), DataFrame(columns=list("ABC")), DataFrame(index=list("ABC")), - DataFrame({"A": [], "B": [], "C": []}), + DataFrame({ + "A": [], + "B": [], + "C": [] + }), ], ) @pytest.mark.parametrize("func", [round, lambda x: x]) @@ -600,8 +646,7 @@ def test_applymap_na_ignore(float_frame): mask = np.random.randint(0, 2, size=float_frame.shape, dtype=bool) float_frame_with_na[mask] = pd.NA strlen_frame_na_ignore = float_frame_with_na.applymap( - lambda x: len(str(x)), na_action="ignore" - ) + lambda x: len(str(x)), na_action="ignore") strlen_frame_with_na = strlen_frame.copy() strlen_frame_with_na[mask] = pd.NA tm.assert_frame_equal(strlen_frame_na_ignore, strlen_frame_with_na) @@ -620,30 +665,28 @@ def func(x): def test_applymap_box(): # ufunc will not be boxed. 
Same test cases as the test_map_box - df = DataFrame( - { - "a": [Timestamp("2011-01-01"), Timestamp("2011-01-02")], - "b": [ - Timestamp("2011-01-01", tz="US/Eastern"), - Timestamp("2011-01-02", tz="US/Eastern"), - ], - "c": [pd.Timedelta("1 days"), pd.Timedelta("2 days")], - "d": [ - pd.Period("2011-01-01", freq="M"), - pd.Period("2011-01-02", freq="M"), - ], - } - ) + df = DataFrame({ + "a": [Timestamp("2011-01-01"), + Timestamp("2011-01-02")], + "b": [ + Timestamp("2011-01-01", tz="US/Eastern"), + Timestamp("2011-01-02", tz="US/Eastern"), + ], + "c": [pd.Timedelta("1 days"), + pd.Timedelta("2 days")], + "d": [ + pd.Period("2011-01-01", freq="M"), + pd.Period("2011-01-02", freq="M"), + ], + }) result = df.applymap(lambda x: type(x).__name__) - expected = DataFrame( - { - "a": ["Timestamp", "Timestamp"], - "b": ["Timestamp", "Timestamp"], - "c": ["Timedelta", "Timedelta"], - "d": ["Period", "Period"], - } - ) + expected = DataFrame({ + "a": ["Timestamp", "Timestamp"], + "b": ["Timestamp", "Timestamp"], + "c": ["Timedelta", "Timedelta"], + "d": ["Period", "Period"], + }) tm.assert_frame_equal(result, expected) @@ -661,14 +704,14 @@ def test_frame_apply_dont_convert_datetime64(): def test_apply_non_numpy_dtype(): # GH 12244 - df = DataFrame({"dt": date_range("2015-01-01", periods=3, tz="Europe/Brussels")}) + df = DataFrame( + {"dt": date_range("2015-01-01", periods=3, tz="Europe/Brussels")}) result = df.apply(lambda x: x) tm.assert_frame_equal(result, df) result = df.apply(lambda x: x + pd.Timedelta("1day")) expected = DataFrame( - {"dt": date_range("2015-01-02", periods=3, tz="Europe/Brussels")} - ) + {"dt": date_range("2015-01-02", periods=3, tz="Europe/Brussels")}) tm.assert_frame_equal(result, expected) @@ -695,17 +738,15 @@ def apply_list(row): df = DataFrame(np.zeros((4, 4)), columns=list("ABCD")) result = getattr(df, op)(apply_list, axis=1) - expected = Series( - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] - ) + expected = Series([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]]) tm.assert_series_equal(result, expected) def test_apply_noreduction_tzaware_object(): # https://github.com/pandas-dev/pandas/issues/31505 - expected = DataFrame( - {"foo": [Timestamp("2020", tz="UTC")]}, dtype="datetime64[ns, UTC]" - ) + expected = DataFrame({"foo": [Timestamp("2020", tz="UTC")]}, + dtype="datetime64[ns, UTC]") result = expected.apply(lambda x: x) tm.assert_frame_equal(result, expected) result = expected.apply(lambda x: x.copy()) @@ -774,7 +815,9 @@ def non_reducing_function(val): def test_apply_with_byte_string(): # GH 34529 df = DataFrame(np.array([b"abcd", b"efgh"]), columns=["col"]) - expected = DataFrame(np.array([b"abcd", b"efgh"]), columns=["col"], dtype=object) + expected = DataFrame(np.array([b"abcd", b"efgh"]), + columns=["col"], + dtype=object) # After we make the apply we expect a dataframe just # like the original but with the object datatype result = df.apply(lambda x: x.astype("object")) @@ -789,8 +832,7 @@ def test_apply_category_equalness(val): result = df.a.apply(lambda x: x == val) expected = Series( - [np.NaN if pd.isnull(x) else x == val for x in df_values], name="a" - ) + [np.NaN if pd.isnull(x) else x == val for x in df_values], name="a") tm.assert_series_equal(result, expected) @@ -836,8 +878,7 @@ def test_with_dictlike_columns_with_datetime(): df["author"] = ["X", "Y", "Z"] df["publisher"] = ["BBC", "NBC", "N24"] df["date"] = pd.to_datetime( - ["17-10-2010 07:15:30", "13-05-2011 08:20:35", "15-01-2013 09:09:09"] - ) + 
["17-10-2010 07:15:30", "13-05-2011 08:20:35", "15-01-2013 09:09:09"]) result = df.apply(lambda x: {}, axis=1) expected = Series([{}, {}, {}]) tm.assert_series_equal(result, expected) @@ -846,7 +887,9 @@ def test_with_dictlike_columns_with_datetime(): def test_with_dictlike_columns_with_infer(): # GH 17602 df = DataFrame([[1, 2], [1, 2]], columns=["a", "b"]) - result = df.apply(lambda x: {"s": x["a"] + x["b"]}, axis=1, result_type="expand") + result = df.apply(lambda x: {"s": x["a"] + x["b"]}, + axis=1, + result_type="expand") expected = DataFrame({"s": [3, 3]}) tm.assert_frame_equal(result, expected) @@ -854,19 +897,19 @@ def test_with_dictlike_columns_with_infer(): Timestamp("2017-05-01 00:00:00"), Timestamp("2017-05-02 00:00:00"), ] - result = df.apply(lambda x: {"s": x["a"] + x["b"]}, axis=1, result_type="expand") + result = df.apply(lambda x: {"s": x["a"] + x["b"]}, + axis=1, + result_type="expand") tm.assert_frame_equal(result, expected) def test_with_listlike_columns(): # GH 17348 - df = DataFrame( - { - "a": Series(np.random.randn(4)), - "b": ["a", "list", "of", "words"], - "ts": date_range("2016-10-01", periods=4, freq="H"), - } - ) + df = DataFrame({ + "a": Series(np.random.randn(4)), + "b": ["a", "list", "of", "words"], + "ts": date_range("2016-10-01", periods=4, freq="H"), + }) result = df[["a", "b"]].apply(tuple, axis=1) expected = Series([t[1:] for t in df[["a", "b"]].itertuples()]) @@ -879,10 +922,14 @@ def test_with_listlike_columns(): def test_with_listlike_columns_returning_list(): # GH 18919 - df = DataFrame({"x": Series([["a", "b"], ["q"]]), "y": Series([["z"], ["q", "t"]])}) + df = DataFrame({ + "x": Series([["a", "b"], ["q"]]), + "y": Series([["z"], ["q", "t"]]) + }) df.index = MultiIndex.from_tuples([("i0", "j0"), ("i1", "j1")]) - result = df.apply(lambda row: [el for el in row["x"] if el in row["y"]], axis=1) + result = df.apply(lambda row: [el for el in row["x"] if el in row["y"]], + axis=1) expected = Series([[], ["q"]], index=df.index) tm.assert_series_equal(result, expected) @@ -890,16 +937,14 @@ def test_with_listlike_columns_returning_list(): def test_infer_output_shape_columns(): # GH 18573 - df = DataFrame( - { - "number": [1.0, 2.0], - "string": ["foo", "bar"], - "datetime": [ - Timestamp("2017-11-29 03:30:00"), - Timestamp("2017-11-29 03:45:00"), - ], - } - ) + df = DataFrame({ + "number": [1.0, 2.0], + "string": ["foo", "bar"], + "datetime": [ + Timestamp("2017-11-29 03:30:00"), + Timestamp("2017-11-29 03:45:00"), + ], + }) result = df.apply(lambda row: (row.number, row.string), axis=1) expected = Series([(t.number, t.string) for t in df.itertuples()]) tm.assert_series_equal(result, expected) @@ -931,19 +976,17 @@ def test_infer_output_shape_listlike_columns_np_func(val): def test_infer_output_shape_listlike_columns_with_timestamp(): # GH 17892 - df = DataFrame( - { - "a": [ - Timestamp("2010-02-01"), - Timestamp("2010-02-04"), - Timestamp("2010-02-05"), - Timestamp("2010-02-06"), - ], - "b": [9, 5, 4, 3], - "c": [5, 3, 4, 2], - "d": [1, 2, 3, 4], - } - ) + df = DataFrame({ + "a": [ + Timestamp("2010-02-01"), + Timestamp("2010-02-04"), + Timestamp("2010-02-05"), + Timestamp("2010-02-06"), + ], + "b": [9, 5, 4, 3], + "c": [5, 3, 4, 2], + "d": [1, 2, 3, 4], + }) def fun(x): return (1, 2) @@ -969,14 +1012,16 @@ def test_consistent_names(int_frame_const_col): df = int_frame_const_col result = df.apply( - lambda x: Series([1, 2, 3], index=["test", "other", "cols"]), axis=1 - ) - expected = int_frame_const_col.rename( - columns={"A": "test", "B": "other", "C": 
"cols"} - ) + lambda x: Series([1, 2, 3], index=["test", "other", "cols"]), axis=1) + expected = int_frame_const_col.rename(columns={ + "A": "test", + "B": "other", + "C": "cols" + }) tm.assert_frame_equal(result, expected) - result = df.apply(lambda x: Series([1, 2], index=["test", "other"]), axis=1) + result = df.apply(lambda x: Series([1, 2], index=["test", "other"]), + axis=1) expected = expected[["test", "other"]] tm.assert_frame_equal(result, expected) @@ -1017,9 +1062,9 @@ def test_result_type_broadcast_series_func(int_frame_const_col): # path we take in the code df = int_frame_const_col columns = ["other", "col", "names"] - result = df.apply( - lambda x: Series([1, 2, 3], index=columns), axis=1, result_type="broadcast" - ) + result = df.apply(lambda x: Series([1, 2, 3], index=columns), + axis=1, + result_type="broadcast") expected = df.copy() tm.assert_frame_equal(result, expected) @@ -1048,7 +1093,10 @@ def test_result_type_series_result_other_index(int_frame_const_col): @pytest.mark.parametrize( "box", - [lambda x: list(x), lambda x: tuple(x), lambda x: np.array(x, dtype="int64")], + [ + lambda x: list(x), lambda x: tuple(x), + lambda x: np.array(x, dtype="int64") + ], ids=["list", "tuple", "array"], ) def test_consistency_for_boxed(box, int_frame_const_col): @@ -1081,9 +1129,11 @@ def test_agg_transform(axis, float_frame): result = float_frame.apply([np.sqrt], axis=axis) expected = f_sqrt.copy() if axis in {0, "index"}: - expected.columns = MultiIndex.from_product([float_frame.columns, ["sqrt"]]) + expected.columns = MultiIndex.from_product( + [float_frame.columns, ["sqrt"]]) else: - expected.index = MultiIndex.from_product([float_frame.index, ["sqrt"]]) + expected.index = MultiIndex.from_product( + [float_frame.index, ["sqrt"]]) tm.assert_frame_equal(result, expected) # multiple items in list @@ -1093,12 +1143,10 @@ def test_agg_transform(axis, float_frame): expected = zip_frames([f_abs, f_sqrt], axis=other_axis) if axis in {0, "index"}: expected.columns = MultiIndex.from_product( - [float_frame.columns, ["absolute", "sqrt"]] - ) + [float_frame.columns, ["absolute", "sqrt"]]) else: expected.index = MultiIndex.from_product( - [float_frame.index, ["absolute", "sqrt"]] - ) + [float_frame.index, ["absolute", "sqrt"]]) tm.assert_frame_equal(result, expected) @@ -1107,9 +1155,12 @@ def test_demo(): df = DataFrame({"A": range(5), "B": 5}) result = df.agg(["min", "max"]) - expected = DataFrame( - {"A": [0, 4], "B": [5, 5]}, columns=["A", "B"], index=["min", "max"] - ) + expected = DataFrame({ + "A": [0, 4], + "B": [5, 5] + }, + columns=["A", "B"], + index=["min", "max"]) tm.assert_frame_equal(result, expected) @@ -1118,7 +1169,10 @@ def test_demo_dict_agg(): df = DataFrame({"A": range(5), "B": 5}) result = df.agg({"A": ["min", "max"], "B": ["sum", "max"]}) expected = DataFrame( - {"A": [4.0, 0.0, np.nan], "B": [5.0, np.nan, 25.0]}, + { + "A": [4.0, 0.0, np.nan], + "B": [5.0, np.nan, 25.0] + }, columns=["A", "B"], index=["max", "min", "sum"], ) @@ -1143,14 +1197,12 @@ def test_agg_with_name_as_column_name(): def test_agg_multiple_mixed_no_warning(): # GH 20909 - mdf = DataFrame( - { - "A": [1, 2, 3], - "B": [1.0, 2.0, 3.0], - "C": ["foo", "bar", "baz"], - "D": date_range("20130101", periods=3), - } - ) + mdf = DataFrame({ + "A": [1, 2, 3], + "B": [1.0, 2.0, 3.0], + "C": ["foo", "bar", "baz"], + "D": date_range("20130101", periods=3), + }) expected = DataFrame( { "A": [1, 6], @@ -1162,15 +1214,13 @@ def test_agg_multiple_mixed_no_warning(): ) # sorted index with 
tm.assert_produces_warning( - FutureWarning, match=r"\['D'\] did not aggregate successfully" - ): + FutureWarning, match=r"\['D'\] did not aggregate successfully"): result = mdf.agg(["min", "sum"]) tm.assert_frame_equal(result, expected) with tm.assert_produces_warning( - FutureWarning, match=r"\['D'\] did not aggregate successfully" - ): + FutureWarning, match=r"\['D'\] did not aggregate successfully"): result = mdf[["D", "C", "B", "A"]].agg(["sum", "min"]) # GH40420: the result of .agg should have an index that is sorted @@ -1213,12 +1263,12 @@ def test_agg_reduce(axis, float_frame): # dict input with lists func = {name1: ["mean"], name2: ["sum"]} result = float_frame.agg(func, axis=axis) - expected = DataFrame( - { - name1: Series([float_frame.loc(other_axis)[name1].mean()], index=["mean"]), - name2: Series([float_frame.loc(other_axis)[name2].sum()], index=["sum"]), - } - ) + expected = DataFrame({ + name1: + Series([float_frame.loc(other_axis)[name1].mean()], index=["mean"]), + name2: + Series([float_frame.loc(other_axis)[name2].sum()], index=["sum"]), + }) expected = expected.T if axis in {1, "columns"} else expected tm.assert_frame_equal(result, expected) @@ -1227,14 +1277,16 @@ def test_agg_reduce(axis, float_frame): result = float_frame.agg(func, axis=axis) expected = pd.concat( { - name1: Series( + name1: + Series( [ float_frame.loc(other_axis)[name1].mean(), float_frame.loc(other_axis)[name1].sum(), ], index=["mean", "sum"], ), - name2: Series( + name2: + Series( [ float_frame.loc(other_axis)[name2].sum(), float_frame.loc(other_axis)[name2].max(), @@ -1251,14 +1303,12 @@ def test_agg_reduce(axis, float_frame): def test_nuiscance_columns(): # GH 15015 - df = DataFrame( - { - "A": [1, 2, 3], - "B": [1.0, 2.0, 3.0], - "C": ["foo", "bar", "baz"], - "D": date_range("20130101", periods=3), - } - ) + df = DataFrame({ + "A": [1, 2, 3], + "B": [1.0, 2.0, 3.0], + "C": ["foo", "bar", "baz"], + "D": date_range("20130101", periods=3), + }) result = df.agg("min") expected = Series([1, 1.0, "bar", Timestamp("20130101")], index=df.columns) @@ -1278,12 +1328,11 @@ def test_nuiscance_columns(): tm.assert_series_equal(result, expected) with tm.assert_produces_warning( - FutureWarning, match=r"\['D'\] did not aggregate successfully" - ): + FutureWarning, match=r"\['D'\] did not aggregate successfully"): result = df.agg(["sum"]) - expected = DataFrame( - [[6, 6.0, "foobarbaz"]], index=["sum"], columns=["A", "B", "C"] - ) + expected = DataFrame([[6, 6.0, "foobarbaz"]], + index=["sum"], + columns=["A", "B", "C"]) tm.assert_frame_equal(result, expected) @@ -1294,9 +1343,11 @@ def test_non_callable_aggregates(how): # 'size' is a property of frame/series # validate that this is working # GH 39116 - expand to apply - df = DataFrame( - {"A": [None, 2, 3], "B": [1.0, np.nan, 3.0], "C": ["foo", None, "bar"]} - ) + df = DataFrame({ + "A": [None, 2, 3], + "B": [1.0, np.nan, 3.0], + "C": ["foo", None, "bar"] + }) # Function aggregate result = getattr(df, how)({"A": "count"}) @@ -1312,16 +1363,25 @@ def test_non_callable_aggregates(how): # Mix function and non-function aggs result1 = getattr(df, how)(["count", "size"]) - result2 = getattr(df, how)( - {"A": ["count", "size"], "B": ["count", "size"], "C": ["count", "size"]} - ) - expected = DataFrame( - { - "A": {"count": 2, "size": 3}, - "B": {"count": 2, "size": 3}, - "C": {"count": 2, "size": 3}, - } - ) + result2 = getattr(df, how)({ + "A": ["count", "size"], + "B": ["count", "size"], + "C": ["count", "size"] + }) + expected = DataFrame({ + "A": { + "count": 2, 
+ "size": 3 + }, + "B": { + "count": 2, + "size": 3 + }, + "C": { + "count": 2, + "size": 3 + }, + }) tm.assert_frame_equal(result1, result2, check_like=True) tm.assert_frame_equal(result2, expected, check_like=True) @@ -1336,9 +1396,11 @@ def test_non_callable_aggregates(how): @pytest.mark.parametrize("how", ["agg", "apply"]) def test_size_as_str(how, axis): # GH 39934 - df = DataFrame( - {"A": [None, 2, 3], "B": [1.0, np.nan, 3.0], "C": ["foo", None, "bar"]} - ) + df = DataFrame({ + "A": [None, 2, 3], + "B": [1.0, np.nan, 3.0], + "C": ["foo", None, "bar"] + }) # Just a string attribute arg same as calling df.arg # on the columns result = getattr(df, how)("size", axis=axis) @@ -1351,7 +1413,11 @@ def test_size_as_str(how, axis): def test_agg_listlike_result(): # GH-29587 user defined function returning list-likes - df = DataFrame({"A": [2, 2, 3], "B": [1.5, np.nan, 1.5], "C": ["foo", None, "bar"]}) + df = DataFrame({ + "A": [2, 2, 3], + "B": [1.5, np.nan, 1.5], + "C": ["foo", None, "bar"] + }) def func(group_col): return list(group_col.dropna().unique()) @@ -1372,13 +1438,26 @@ def func(group_col): ((1, 2, 3), {}), ((8, 7, 15), {}), ((1, 2), {}), - ((1,), {"b": 2}), - ((), {"a": 1, "b": 2}), - ((), {"a": 2, "b": 1}), - ((), {"a": 1, "b": 2, "c": 3}), + ((1, ), { + "b": 2 + }), + ((), { + "a": 1, + "b": 2 + }), + ((), { + "a": 2, + "b": 1 + }), + ((), { + "a": 1, + "b": 2, + "c": 3 + }), ], ) def test_agg_args_kwargs(axis, args, kwargs): + def f(x, a, b, c=3): return x.sum() + (a + b) / c @@ -1419,7 +1498,11 @@ def test_apply_datetime_tz_issue(): tm.assert_series_equal(result, expected) -@pytest.mark.parametrize("df", [DataFrame({"A": ["a", None], "B": ["c", "d"]})]) +@pytest.mark.parametrize("df", + [DataFrame({ + "A": ["a", None], + "B": ["c", "d"] + })]) @pytest.mark.parametrize("method", ["min", "max", "sum"]) def test_consistency_of_aggregates_of_columns_with_missing_values(df, method): # GH 16832 @@ -1427,7 +1510,8 @@ def test_consistency_of_aggregates_of_columns_with_missing_values(df, method): none_in_first_column_result = getattr(df[["A", "B"]], method)() none_in_second_column_result = getattr(df[["B", "A"]], method)() - tm.assert_series_equal(none_in_first_column_result, none_in_second_column_result) + tm.assert_series_equal(none_in_first_column_result, + none_in_second_column_result) @pytest.mark.parametrize("col", [1, 1.0, True, "a", np.nan]) @@ -1467,7 +1551,8 @@ def func(row): def test_apply_empty_list_reduce(): # GH#35683 get columns correct - df = DataFrame([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]], columns=["a", "b"]) + df = DataFrame([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]], + columns=["a", "b"]) result = df.apply(lambda x: [], result_type="reduce") expected = Series({"a": [], "b": []}, dtype=object) @@ -1478,9 +1563,11 @@ def test_apply_no_suffix_index(): # GH36189 pdf = DataFrame([[4, 9]] * 3, columns=["A", "B"]) result = pdf.apply(["sum", lambda x: x.sum(), lambda x: x.sum()]) - expected = DataFrame( - {"A": [12, 12, 12], "B": [27, 27, 27]}, index=["sum", "", ""] - ) + expected = DataFrame({ + "A": [12, 12, 12], + "B": [27, 27, 27] + }, + index=["sum", "", ""]) tm.assert_frame_equal(result, expected) @@ -1513,8 +1600,7 @@ def foo(s): aggs = ["sum", foo, "count", "min"] with tm.assert_produces_warning( - FutureWarning, match=r"\['item'\] did not aggregate successfully" - ): + FutureWarning, match=r"\['item'\] did not aggregate successfully"): result = df.agg(aggs) expected = DataFrame( { diff --git a/pandas/tests/apply/test_frame_transform.py 
b/pandas/tests/apply/test_frame_transform.py index 9caae8e616b36..2f37970b0d9db 100644 --- a/pandas/tests/apply/test_frame_transform.py +++ b/pandas/tests/apply/test_frame_transform.py @@ -51,7 +51,8 @@ def test_transform_listlike(axis, float_frame, ops, names): with np.errstate(all="ignore"): expected = zip_frames([op(float_frame) for op in ops], axis=other_axis) if axis in {0, "index"}: - expected.columns = MultiIndex.from_product([float_frame.columns, names]) + expected.columns = MultiIndex.from_product( + [float_frame.columns, names]) else: expected.index = MultiIndex.from_product([float_frame.index, names]) result = float_frame.transform(ops, axis=axis) @@ -62,7 +63,8 @@ def test_transform_listlike(axis, float_frame, ops, names): def test_transform_empty_listlike(float_frame, ops, frame_or_series): obj = unpack_obj(float_frame, frame_or_series, 0) - with pytest.raises(ValueError, match="No transform functions were provided"): + with pytest.raises(ValueError, + match="No transform functions were provided"): obj.transform(ops) @@ -85,7 +87,8 @@ def test_transform_dictlike_mixed(): result = df.transform({"b": ["sqrt", "abs"], "c": "sqrt"}) expected = DataFrame( [[1.0, 1, 1.0], [2.0, 4, 2.0]], - columns=MultiIndex([("b", "c"), ("sqrt", "abs")], [(0, 0, 1), (0, 1, 0)]), + columns=MultiIndex([("b", "c"), ("sqrt", "abs")], [(0, 0, 1), + (0, 1, 0)]), ) tm.assert_frame_equal(result, expected) @@ -94,17 +97,32 @@ def test_transform_dictlike_mixed(): "ops", [ {}, - {"A": []}, - {"A": [], "B": "cumsum"}, - {"A": "cumsum", "B": []}, - {"A": [], "B": ["cumsum"]}, - {"A": ["cumsum"], "B": []}, + { + "A": [] + }, + { + "A": [], + "B": "cumsum" + }, + { + "A": "cumsum", + "B": [] + }, + { + "A": [], + "B": ["cumsum"] + }, + { + "A": ["cumsum"], + "B": [] + }, ], ) def test_transform_empty_dictlike(float_frame, ops, frame_or_series): obj = unpack_obj(float_frame, frame_or_series, 0) - with pytest.raises(ValueError, match="No transform functions were provided"): + with pytest.raises(ValueError, + match="No transform functions were provided"): obj.transform(ops) @@ -127,7 +145,9 @@ def func(x): wont_fail = ["ffill", "bfill", "fillna", "pad", "backfill", "shift"] -frame_kernels_raise = [x for x in frame_transform_kernels if x not in wont_fail] +frame_kernels_raise = [ + x for x in frame_transform_kernels if x not in wont_fail +] @pytest.mark.parametrize("op", [*frame_kernels_raise, lambda x: x + 1]) @@ -136,17 +156,18 @@ def test_transform_bad_dtype(op, frame_or_series, request): if op == "rank": request.node.add_marker( pytest.mark.xfail( - raises=ValueError, reason="GH 40418: rank does not raise a TypeError" - ) - ) + raises=ValueError, + reason="GH 40418: rank does not raise a TypeError")) - obj = DataFrame({"A": 3 * [object]}) # DataFrame that will fail on most transforms + obj = DataFrame({"A": 3 * [object] + }) # DataFrame that will fail on most transforms obj = tm.get_obj(obj, frame_or_series) # tshift is deprecated warn = None if op != "tshift" else FutureWarning with tm.assert_produces_warning(warn): - with pytest.raises(TypeError, match="unsupported operand|not supported"): + with pytest.raises(TypeError, + match="unsupported operand|not supported"): obj.transform(op) with pytest.raises(TypeError, match="Transform function failed"): obj.transform([op]) diff --git a/pandas/tests/apply/test_str.py b/pandas/tests/apply/test_str.py index 8a5d09cc1c9b5..768cb9612eda5 100644 --- a/pandas/tests/apply/test_str.py +++ b/pandas/tests/apply/test_str.py @@ -38,8 +38,7 @@ def 
test_apply_with_string_funcs(request, float_frame, func, args, kwds, how): raises=TypeError, reason="agg/apply signature mismatch - agg passes 2nd " "argument to func", - ) - ) + )) result = getattr(float_frame, how)(func, *args, **kwds) expected = getattr(float_frame, func)(*args, **kwds) tm.assert_series_equal(result, expected) @@ -60,15 +59,13 @@ def test_apply_np_reducer(op, how): result = getattr(float_frame, how)(op) # pandas ddof defaults to 1, numpy to 0 kwargs = {"ddof": 1} if op in ("std", "var") else {} - expected = Series( - getattr(np, op)(float_frame, axis=0, **kwargs), index=float_frame.columns - ) + expected = Series(getattr(np, op)(float_frame, axis=0, **kwargs), + index=float_frame.columns) tm.assert_series_equal(result, expected) @pytest.mark.parametrize( - "op", ["abs", "ceil", "cos", "cumsum", "exp", "log", "sqrt", "square"] -) + "op", ["abs", "ceil", "cos", "cumsum", "exp", "log", "sqrt", "square"]) @pytest.mark.parametrize("how", ["transform", "apply"]) def test_apply_np_transformer(float_frame, op, how): # GH 39116 @@ -159,9 +156,8 @@ def test_agg_cython_table_series(series, func, expected): ("cumsum", Series([np.nan, 1, 3, 6])), ], ), - tm.get_cython_table_params( - Series("a b c".split()), [("cumsum", Series(["a", "ab", "abc"]))] - ), + tm.get_cython_table_params(Series("a b c".split()), + [("cumsum", Series(["a", "ab", "abc"]))]), ), ) def test_agg_cython_table_transform_series(series, func, expected): @@ -218,9 +214,8 @@ def test_agg_cython_table_frame(df, func, expected, axis): @pytest.mark.parametrize( "df, func, expected", chain( - tm.get_cython_table_params( - DataFrame(), [("cumprod", DataFrame()), ("cumsum", DataFrame())] - ), + tm.get_cython_table_params(DataFrame(), [("cumprod", DataFrame()), + ("cumsum", DataFrame())]), tm.get_cython_table_params( DataFrame([[np.nan, 1], [1, 2]]), [ @@ -283,7 +278,8 @@ def test_transform_groupby_kernel_frame(axis, float_frame, op): tm.assert_frame_equal(result2, expected2) -@pytest.mark.parametrize("method", ["abs", "shift", "pct_change", "cumsum", "rank"]) +@pytest.mark.parametrize("method", + ["abs", "shift", "pct_change", "cumsum", "rank"]) def test_transform_method_name(method): # GH 19760 df = DataFrame({"A": [-1, 2]}) diff --git a/pandas/tests/frame/methods/test_interpolate.py b/pandas/tests/frame/methods/test_interpolate.py index 59e941ff9e4ab..5f8673b945c43 100644 --- a/pandas/tests/frame/methods/test_interpolate.py +++ b/pandas/tests/frame/methods/test_interpolate.py @@ -12,10 +12,13 @@ class TestDataFrameInterpolate: - def test_interpolate_inplace(self, frame_or_series, using_array_manager, request): + + def test_interpolate_inplace(self, frame_or_series, using_array_manager, + request): # GH#44749 if using_array_manager and frame_or_series is DataFrame: - mark = pytest.mark.xfail(reason=".values-based in-place check is invalid") + mark = pytest.mark.xfail( + reason=".values-based in-place check is invalid") request.node.add_marker(mark) obj = frame_or_series([1, np.nan, 2]) @@ -30,22 +33,18 @@ def test_interpolate_inplace(self, frame_or_series, using_array_manager, request assert orig.squeeze()[1] == 1.5 def test_interp_basic(self): - df = DataFrame( - { - "A": [1, 2, np.nan, 4], - "B": [1, 4, 9, np.nan], - "C": [1, 2, 3, 5], - "D": list("abcd"), - } - ) - expected = DataFrame( - { - "A": [1.0, 2.0, 3.0, 4.0], - "B": [1.0, 4.0, 9.0, 9.0], - "C": [1, 2, 3, 5], - "D": list("abcd"), - } - ) + df = DataFrame({ + "A": [1, 2, np.nan, 4], + "B": [1, 4, 9, np.nan], + "C": [1, 2, 3, 5], + "D": list("abcd"), + }) + 
expected = DataFrame({ + "A": [1.0, 2.0, 3.0, 4.0], + "B": [1.0, 4.0, 9.0, 9.0], + "C": [1, 2, 3, 5], + "D": list("abcd"), + }) result = df.interpolate() tm.assert_frame_equal(result, expected) @@ -64,22 +63,18 @@ def test_interp_basic(self): assert np.shares_memory(df["D"]._values, dvalues) def test_interp_basic_with_non_range_index(self): - df = DataFrame( - { - "A": [1, 2, np.nan, 4], - "B": [1, 4, 9, np.nan], - "C": [1, 2, 3, 5], - "D": list("abcd"), - } - ) - expected = DataFrame( - { - "A": [1.0, 2.0, 3.0, 4.0], - "B": [1.0, 4.0, 9.0, 9.0], - "C": [1, 2, 3, 5], - "D": list("abcd"), - } - ) + df = DataFrame({ + "A": [1, 2, np.nan, 4], + "B": [1, 4, 9, np.nan], + "C": [1, 2, 3, 5], + "D": list("abcd"), + }) + expected = DataFrame({ + "A": [1.0, 2.0, 3.0, 4.0], + "B": [1.0, 4.0, 9.0, 9.0], + "C": [1, 2, 3, 5], + "D": list("abcd"), + }) result = df.set_index("C").interpolate() expected = df.set_index("C") @@ -96,33 +91,28 @@ def test_interp_empty(self): tm.assert_frame_equal(result, expected) def test_interp_bad_method(self): - df = DataFrame( - { - "A": [1, 2, np.nan, 4], - "B": [1, 4, 9, np.nan], - "C": [1, 2, 3, 5], - "D": list("abcd"), - } - ) + df = DataFrame({ + "A": [1, 2, np.nan, 4], + "B": [1, 4, 9, np.nan], + "C": [1, 2, 3, 5], + "D": list("abcd"), + }) msg = ( r"method must be one of \['linear', 'time', 'index', 'values', " r"'nearest', 'zero', 'slinear', 'quadratic', 'cubic', " r"'barycentric', 'krogh', 'spline', 'polynomial', " r"'from_derivatives', 'piecewise_polynomial', 'pchip', 'akima', " - r"'cubicspline'\]. Got 'not_a_method' instead." - ) + r"'cubicspline'\]. Got 'not_a_method' instead.") with pytest.raises(ValueError, match=msg): df.interpolate(method="not_a_method") def test_interp_combo(self): - df = DataFrame( - { - "A": [1.0, 2.0, np.nan, 4.0], - "B": [1, 4, 9, np.nan], - "C": [1, 2, 3, 5], - "D": list("abcd"), - } - ) + df = DataFrame({ + "A": [1.0, 2.0, np.nan, 4.0], + "B": [1, 4, 9, np.nan], + "C": [1, 2, 3, 5], + "D": list("abcd"), + }) result = df["A"].interpolate() expected = Series([1.0, 2.0, 3.0, 4.0], name="A") @@ -137,16 +127,16 @@ def test_interp_nan_idx(self): df = df.set_index("A") msg = ( "Interpolation with NaNs in the index has not been implemented. " - "Try filling those NaNs before interpolating." 
- ) + "Try filling those NaNs before interpolating.") with pytest.raises(NotImplementedError, match=msg): df.interpolate(method="values") @td.skip_if_no_scipy def test_interp_various(self): - df = DataFrame( - {"A": [1, 2, np.nan, 4, 5, np.nan, 7], "C": [1, 2, 3, 5, 8, 13, 21]} - ) + df = DataFrame({ + "A": [1, 2, np.nan, 4, 5, np.nan, 7], + "C": [1, 2, 3, 5, 8, 13, 21] + }) df = df.set_index("C") expected = df.copy() result = df.interpolate(method="polynomial", order=1) @@ -183,9 +173,10 @@ def test_interp_various(self): @td.skip_if_no_scipy def test_interp_alt_scipy(self): - df = DataFrame( - {"A": [1, 2, np.nan, 4, 5, np.nan, 7], "C": [1, 2, 3, 5, 8, 13, 21]} - ) + df = DataFrame({ + "A": [1, 2, np.nan, 4, 5, np.nan, 7], + "C": [1, 2, 3, 5, 8, 13, 21] + }) result = df.interpolate(method="barycentric") expected = df.copy() expected.loc[2, "A"] = 3 @@ -207,15 +198,13 @@ def test_interp_alt_scipy(self): tm.assert_frame_equal(result, expected) def test_interp_rowwise(self): - df = DataFrame( - { - 0: [1, 2, np.nan, 4], - 1: [2, 3, 4, np.nan], - 2: [np.nan, 4, 5, 6], - 3: [4, np.nan, 6, 7], - 4: [1, 2, 3, 4], - } - ) + df = DataFrame({ + 0: [1, 2, np.nan, 4], + 1: [2, 3, 4, np.nan], + 2: [np.nan, 4, 5, 6], + 3: [4, np.nan, 6, 7], + 4: [1, 2, 3, 4], + }) result = df.interpolate(axis=1) expected = df.copy() expected.loc[3, 1] = 5 @@ -249,22 +238,21 @@ def test_interp_axis_names(self, axis_name, axis_number): tm.assert_frame_equal(result, expected) def test_rowwise_alt(self): - df = DataFrame( - { - 0: [0, 0.5, 1.0, np.nan, 4, 8, np.nan, np.nan, 64], - 1: [1, 2, 3, 4, 3, 2, 1, 0, -1], - } - ) + df = DataFrame({ + 0: [0, 0.5, 1.0, np.nan, 4, 8, np.nan, np.nan, 64], + 1: [1, 2, 3, 4, 3, 2, 1, 0, -1], + }) df.interpolate(axis=0) # TODO: assert something? @pytest.mark.parametrize( - "check_scipy", [False, pytest.param(True, marks=td.skip_if_no_scipy)] - ) + "check_scipy", + [False, pytest.param(True, marks=td.skip_if_no_scipy)]) def test_interp_leading_nans(self, check_scipy): - df = DataFrame( - {"A": [np.nan, np.nan, 0.5, 0.25, 0], "B": [np.nan, -3, -3.5, np.nan, -4]} - ) + df = DataFrame({ + "A": [np.nan, np.nan, 0.5, 0.25, 0], + "B": [np.nan, -3, -3.5, np.nan, -4] + }) result = df.interpolate() expected = df.copy() expected.loc[3, "B"] = -3.75 @@ -275,31 +263,25 @@ def test_interp_leading_nans(self, check_scipy): tm.assert_frame_equal(result, expected) def test_interp_raise_on_only_mixed(self, axis): - df = DataFrame( - { - "A": [1, 2, np.nan, 4], - "B": ["a", "b", "c", "d"], - "C": [np.nan, 2, 5, 7], - "D": [np.nan, np.nan, 9, 9], - "E": [1, 2, 3, 4], - } - ) - msg = ( - "Cannot interpolate with all object-dtype columns " - "in the DataFrame. Try setting at least one " - "column to a numeric dtype." - ) + df = DataFrame({ + "A": [1, 2, np.nan, 4], + "B": ["a", "b", "c", "d"], + "C": [np.nan, 2, 5, 7], + "D": [np.nan, np.nan, 9, 9], + "E": [1, 2, 3, 4], + }) + msg = ("Cannot interpolate with all object-dtype columns " + "in the DataFrame. Try setting at least one " + "column to a numeric dtype.") with pytest.raises(TypeError, match=msg): df.astype("object").interpolate(axis=axis) def test_interp_raise_on_all_object_dtype(self): # GH 22985 df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, dtype="object") - msg = ( - "Cannot interpolate with all object-dtype columns " - "in the DataFrame. Try setting at least one " - "column to a numeric dtype." - ) + msg = ("Cannot interpolate with all object-dtype columns " + "in the DataFrame. 
Try setting at least one " + "column to a numeric dtype.") with pytest.raises(TypeError, match=msg): df.interpolate() @@ -318,32 +300,36 @@ def test_interp_inplace(self): def test_interp_inplace_row(self): # GH 10395 - result = DataFrame( - {"a": [1.0, 2.0, 3.0, 4.0], "b": [np.nan, 2.0, 3.0, 4.0], "c": [3, 2, 2, 2]} - ) + result = DataFrame({ + "a": [1.0, 2.0, 3.0, 4.0], + "b": [np.nan, 2.0, 3.0, 4.0], + "c": [3, 2, 2, 2] + }) expected = result.interpolate(method="linear", axis=1, inplace=False) - return_value = result.interpolate(method="linear", axis=1, inplace=True) + return_value = result.interpolate(method="linear", + axis=1, + inplace=True) assert return_value is None tm.assert_frame_equal(result, expected) def test_interp_ignore_all_good(self): # GH - df = DataFrame( - { - "A": [1, 2, np.nan, 4], - "B": [1, 2, 3, 4], - "C": [1.0, 2.0, np.nan, 4.0], - "D": [1.0, 2.0, 3.0, 4.0], - } - ) - expected = DataFrame( - { - "A": np.array([1, 2, 3, 4], dtype="float64"), - "B": np.array([1, 2, 3, 4], dtype="int64"), - "C": np.array([1.0, 2.0, 3, 4.0], dtype="float64"), - "D": np.array([1.0, 2.0, 3.0, 4.0], dtype="float64"), - } - ) + df = DataFrame({ + "A": [1, 2, np.nan, 4], + "B": [1, 2, 3, 4], + "C": [1.0, 2.0, np.nan, 4.0], + "D": [1.0, 2.0, 3.0, 4.0], + }) + expected = DataFrame({ + "A": + np.array([1, 2, 3, 4], dtype="float64"), + "B": + np.array([1, 2, 3, 4], dtype="int64"), + "C": + np.array([1.0, 2.0, 3, 4.0], dtype="float64"), + "D": + np.array([1.0, 2.0, 3.0, 4.0], dtype="float64"), + }) result = df.interpolate(downcast=None) tm.assert_frame_equal(result, expected) @@ -361,36 +347,38 @@ def test_interp_time_inplace_axis(self): expected = DataFrame(index=idx, columns=idx, data=data) result = expected.interpolate(axis=0, method="time") - return_value = expected.interpolate(axis=0, method="time", inplace=True) + return_value = expected.interpolate(axis=0, + method="time", + inplace=True) assert return_value is None tm.assert_frame_equal(result, expected) - @pytest.mark.parametrize("axis_name, axis_number", [("index", 0), ("columns", 1)]) + @pytest.mark.parametrize("axis_name, axis_number", [("index", 0), + ("columns", 1)]) def test_interp_string_axis(self, axis_name, axis_number): # https://github.com/pandas-dev/pandas/issues/25190 x = np.linspace(0, 100, 1000) y = np.sin(x) - df = DataFrame( - data=np.tile(y, (10, 1)), index=np.arange(10), columns=x - ).reindex(columns=x * 1.005) + df = DataFrame(data=np.tile(y, (10, 1)), + index=np.arange(10), + columns=x).reindex(columns=x * 1.005) result = df.interpolate(method="linear", axis=axis_name) expected = df.interpolate(method="linear", axis=axis_number) tm.assert_frame_equal(result, expected) @pytest.mark.parametrize("method", ["ffill", "bfill", "pad"]) - def test_interp_fillna_methods(self, request, axis, method, using_array_manager): + def test_interp_fillna_methods(self, request, axis, method, + using_array_manager): # GH 12918 if using_array_manager and axis in (1, "columns"): # TODO(ArrayManager) support axis=1 td.mark_array_manager_not_yet_implemented(request) - df = DataFrame( - { - "A": [1.0, 2.0, 3.0, 4.0, np.nan, 5.0], - "B": [2.0, 4.0, 6.0, np.nan, 8.0, 10.0], - "C": [3.0, 6.0, 9.0, np.nan, np.nan, 30.0], - } - ) + df = DataFrame({ + "A": [1.0, 2.0, 3.0, 4.0, np.nan, 5.0], + "B": [2.0, 4.0, 6.0, np.nan, 8.0, 10.0], + "C": [3.0, 6.0, 9.0, np.nan, np.nan, 30.0], + }) expected = df.fillna(axis=axis, method=method) result = df.interpolate(method=method, axis=axis) tm.assert_frame_equal(result, expected) @@ -400,8 +388,7 @@ def 
test_interpolate_pos_args_deprecation(self): df = DataFrame({"a": [1, 2, 3]}) msg = ( r"In a future version of pandas all arguments of DataFrame.interpolate " - r"except for the argument 'method' will be keyword-only" - ) + r"except for the argument 'method' will be keyword-only") with tm.assert_produces_warning(FutureWarning, match=msg): result = df.interpolate("pad", 0) expected = DataFrame({"a": [1, 2, 3]}) diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index c53547d00dbeb..bc9c06017ce2b 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -47,26 +47,20 @@ def test_union_different_types(index_flat, index_flat2, request): idx1 = index_flat idx2 = index_flat2 - if ( - not idx1.is_unique - and not idx2.is_unique - and idx1.dtype.kind == "i" - and idx2.dtype.kind == "b" - ) or ( - not idx2.is_unique - and not idx1.is_unique - and idx2.dtype.kind == "i" - and idx1.dtype.kind == "b" - ): + if (not idx1.is_unique and not idx2.is_unique and idx1.dtype.kind == "i" + and idx2.dtype.kind == "b") or (not idx2.is_unique + and not idx1.is_unique + and idx2.dtype.kind == "i" + and idx1.dtype.kind == "b"): # Each condition had idx[1|2].is_monotonic_decreasing # but failed when e.g. # idx1 = Index( # [True, True, True, True, True, True, True, True, False, False], dtype='bool' # ) # idx2 = Int64Index([0, 0, 1, 1, 2, 2], dtype='int64') - mark = pytest.mark.xfail( - reason="GH#44000 True==1", raises=ValueError, strict=False - ) + mark = pytest.mark.xfail(reason="GH#44000 True==1", + raises=ValueError, + strict=False) request.node.add_marker(mark) common_dtype = find_common_type([idx1.dtype, idx2.dtype]) @@ -74,19 +68,12 @@ def test_union_different_types(index_flat, index_flat2, request): warn = None if not len(idx1) or not len(idx2): pass - elif ( - idx1.dtype.kind == "c" - and ( - idx2.dtype.kind not in ["i", "u", "f", "c"] - or not isinstance(idx2.dtype, np.dtype) - ) - ) or ( - idx2.dtype.kind == "c" - and ( - idx1.dtype.kind not in ["i", "u", "f", "c"] - or not isinstance(idx1.dtype, np.dtype) - ) - ): + elif (idx1.dtype.kind == "c" and + (idx2.dtype.kind not in ["i", "u", "f", "c"] + or not isinstance(idx2.dtype, np.dtype))) or ( + idx2.dtype.kind == "c" and + (idx1.dtype.kind not in ["i", "u", "f", "c"] + or not isinstance(idx1.dtype, np.dtype))): # complex objects non-sortable warn = RuntimeWarning @@ -152,7 +139,8 @@ def test_compatible_inconsistent_pairs(idx_fact1, idx_fact2): ("Period[D]", "float64", "object"), ], ) -@pytest.mark.parametrize("names", [("foo", "foo", "foo"), ("foo", "bar", None)]) +@pytest.mark.parametrize("names", [("foo", "foo", "foo"), + ("foo", "bar", None)]) def test_union_dtypes(left, right, expected, names): left = pandas_dtype(left) right = pandas_dtype(right) @@ -195,8 +183,8 @@ class TestSetOps: # Set operation tests shared by all indexes in the `index` fixture @pytest.mark.parametrize("case", [0.5, "xxx"]) @pytest.mark.parametrize( - "method", ["intersection", "union", "difference", "symmetric_difference"] - ) + "method", + ["intersection", "union", "difference", "symmetric_difference"]) def test_set_ops_error_cases(self, case, method, index): # non-iterable input msg = "Input must be Index or array-like" @@ -314,7 +302,8 @@ def test_symmetric_difference(self, index): (None, None, None), ], ) - def test_corner_union(self, index_flat_unique, fname, sname, expected_name): + def test_corner_union(self, index_flat_unique, fname, sname, + expected_name): # GH#9943, GH#9862 # Test unions with 
various name combinations # Do not test MultiIndex or repeats @@ -358,7 +347,8 @@ def test_corner_union(self, index_flat_unique, fname, sname, expected_name): (None, None, None), ], ) - def test_union_unequal(self, index_flat_unique, fname, sname, expected_name): + def test_union_unequal(self, index_flat_unique, fname, sname, + expected_name): index = index_flat_unique # test copy.union(subset) - need sort for unicode and string @@ -378,7 +368,8 @@ def test_union_unequal(self, index_flat_unique, fname, sname, expected_name): (None, None, None), ], ) - def test_corner_intersect(self, index_flat_unique, fname, sname, expected_name): + def test_corner_intersect(self, index_flat_unique, fname, sname, + expected_name): # GH#35847 # Test intersections with various name combinations index = index_flat_unique @@ -421,7 +412,8 @@ def test_corner_intersect(self, index_flat_unique, fname, sname, expected_name): (None, None, None), ], ) - def test_intersect_unequal(self, index_flat_unique, fname, sname, expected_name): + def test_intersect_unequal(self, index_flat_unique, fname, sname, + expected_name): index = index_flat_unique # test copy.intersection(subset) - need sort for unicode and string @@ -485,8 +477,7 @@ def test_intersection_difference_match_empty(self, index, sort): @pytest.mark.parametrize( - "method", ["intersection", "union", "difference", "symmetric_difference"] -) + "method", ["intersection", "union", "difference", "symmetric_difference"]) def test_setop_with_categorical(index_flat, sort, method): # MultiIndex tested separately in tests.indexes.multi.test_setops index = index_flat @@ -684,9 +675,8 @@ def test_intersection_name_preservation(self, index2, keeps_name, sort): "first_name,second_name,expected_name", [("A", "A", "A"), ("A", "B", None), (None, "B", None)], ) - def test_intersection_name_preservation2( - self, index, first_name, second_name, expected_name, sort - ): + def test_intersection_name_preservation2(self, index, first_name, + second_name, expected_name, sort): first = index[5:20] second = index[:10] first.name = first_name @@ -752,8 +742,10 @@ def test_union_identity(self, index, sort): assert (union is first) is (not sort) @pytest.mark.parametrize("index", ["string"], indirect=True) - @pytest.mark.parametrize("second_name,expected", [(None, None), ("name", "name")]) - def test_difference_name_preservation(self, index, second_name, expected, sort): + @pytest.mark.parametrize("second_name,expected", [(None, None), + ("name", "name")]) + def test_difference_name_preservation(self, index, second_name, expected, + sort): first = index[5:20] second = index[:10] answer = index[10:20] @@ -864,6 +856,8 @@ def test_symmetric_difference_non_index(self, sort): assert tm.equalContents(result, expected) assert result.name == "index1" - result = index1.symmetric_difference(index2, result_name="new_name", sort=sort) + result = index1.symmetric_difference(index2, + result_name="new_name", + sort=sort) assert tm.equalContents(result, expected) assert result.name == "new_name" diff --git a/pandas/tests/indexing/multiindex/test_slice.py b/pandas/tests/indexing/multiindex/test_slice.py index 8437cffcd14c0..a3756d6a4d2af 100644 --- a/pandas/tests/indexing/multiindex/test_slice.py +++ b/pandas/tests/indexing/multiindex/test_slice.py @@ -16,58 +16,52 @@ class TestMultiIndexSlicers: + def test_per_axis_per_level_getitem(self): # GH6134 # example test case ix = MultiIndex.from_product( - [_mklbl("A", 5), _mklbl("B", 7), _mklbl("C", 4), _mklbl("D", 2)] - ) + [_mklbl("A", 5), + _mklbl("B", 
7), + _mklbl("C", 4), + _mklbl("D", 2)]) df = DataFrame(np.arange(len(ix.to_numpy())), index=ix) result = df.loc[(slice("A1", "A3"), slice(None), ["C1", "C3"]), :] - expected = df.loc[ - [ - ( - a, - b, - c, - d, - ) - for a, b, c, d in df.index.values - if a in ("A1", "A2", "A3") and c in ("C1", "C3") - ] - ] + expected = df.loc[[( + a, + b, + c, + d, + ) for a, b, c, d in df.index.values + if a in ("A1", "A2", "A3") and c in ("C1", "C3")]] tm.assert_frame_equal(result, expected) - expected = df.loc[ - [ - ( - a, - b, - c, - d, - ) - for a, b, c, d in df.index.values - if a in ("A1", "A2", "A3") - and c in ("C1", "C2", "C3") - ] - ] + expected = df.loc[[ + ( + a, + b, + c, + d, + ) for a, b, c, d in df.index.values + if a in ("A1", "A2", "A3") and c in ("C1", "C2", "C3") + ]] result = df.loc[(slice("A1", "A3"), slice(None), slice("C1", "C3")), :] tm.assert_frame_equal(result, expected) # test multi-index slicing with per axis and per index controls - index = MultiIndex.from_tuples( - [("A", 1), ("A", 2), ("A", 3), ("B", 1)], names=["one", "two"] - ) + index = MultiIndex.from_tuples([("A", 1), ("A", 2), ("A", 3), + ("B", 1)], + names=["one", "two"]) columns = MultiIndex.from_tuples( [("a", "foo"), ("a", "bar"), ("b", "foo"), ("b", "bah")], names=["lvl0", "lvl1"], ) - df = DataFrame( - np.arange(16, dtype="int64").reshape(4, 4), index=index, columns=columns - ) + df = DataFrame(np.arange(16, dtype="int64").reshape(4, 4), + index=index, + columns=columns) df = df.sort_index(axis=0).sort_index(axis=1) # identity @@ -99,7 +93,10 @@ def test_per_axis_per_level_getitem(self): result = df.loc["A", "a"] expected = DataFrame( - {"bar": [1, 5, 9], "foo": [0, 4, 8]}, + { + "bar": [1, 5, 9], + "foo": [0, 4, 8] + }, index=Index([1, 2, 3], name="two"), columns=Index(["bar", "foo"], name="lvl1"), ) @@ -112,18 +109,13 @@ def test_per_axis_per_level_getitem(self): # multi-level series s = Series(np.arange(len(ix.to_numpy())), index=ix) result = s.loc["A1":"A3", :, ["C1", "C3"]] - expected = s.loc[ - [ - ( - a, - b, - c, - d, - ) - for a, b, c, d in s.index.values - if a in ("A1", "A2", "A3") and c in ("C1", "C3") - ] - ] + expected = s.loc[[( + a, + b, + c, + d, + ) for a, b, c, d in s.index.values + if a in ("A1", "A2", "A3") and c in ("C1", "C3")]] tm.assert_series_equal(result, expected) # boolean indexers @@ -131,10 +123,8 @@ def test_per_axis_per_level_getitem(self): expected = df.iloc[[2, 3]] tm.assert_frame_equal(result, expected) - msg = ( - "cannot index with a boolean indexer " - "that is not the same length as the index" - ) + msg = ("cannot index with a boolean indexer " + "that is not the same length as the index") with pytest.raises(ValueError, match=msg): df.loc[(slice(None), np.array([True, False])), :] @@ -149,10 +139,8 @@ def test_per_axis_per_level_getitem(self): df = df.sort_index(level=1, axis=0) assert df.index._lexsort_depth == 0 - msg = ( - "MultiIndex slicing requires the index to be " - r"lexsorted: slicing on levels \[1\], lexsort depth 0" - ) + msg = ("MultiIndex slicing requires the index to be " + r"lexsorted: slicing on levels \[1\], lexsort depth 0") with pytest.raises(UnsortedIndexError, match=msg): df.loc[(slice(None), slice("bar")), :] @@ -164,24 +152,19 @@ def test_multiindex_slicers_non_unique(self): # GH 7106 # non-unique mi index support - df = ( - DataFrame( - { - "A": ["foo", "foo", "foo", "foo"], - "B": ["a", "a", "a", "a"], - "C": [1, 2, 1, 3], - "D": [1, 2, 3, 4], - } - ) - .set_index(["A", "B", "C"]) - .sort_index() - ) + df = (DataFrame({ + "A": ["foo", "foo", 
"foo", "foo"], + "B": ["a", "a", "a", "a"], + "C": [1, 2, 1, 3], + "D": [1, 2, 3, 4], + }).set_index(["A", "B", "C"]).sort_index()) assert not df.index.is_unique - expected = ( - DataFrame({"A": ["foo", "foo"], "B": ["a", "a"], "C": [1, 1], "D": [1, 3]}) - .set_index(["A", "B", "C"]) - .sort_index() - ) + expected = (DataFrame({ + "A": ["foo", "foo"], + "B": ["a", "a"], + "C": [1, 1], + "D": [1, 3] + }).set_index(["A", "B", "C"]).sort_index()) result = df.loc[(slice(None), slice(None), 1), :] tm.assert_frame_equal(result, expected) @@ -189,24 +172,19 @@ def test_multiindex_slicers_non_unique(self): result = df.xs(1, level=2, drop_level=False) tm.assert_frame_equal(result, expected) - df = ( - DataFrame( - { - "A": ["foo", "foo", "foo", "foo"], - "B": ["a", "a", "a", "a"], - "C": [1, 2, 1, 2], - "D": [1, 2, 3, 4], - } - ) - .set_index(["A", "B", "C"]) - .sort_index() - ) + df = (DataFrame({ + "A": ["foo", "foo", "foo", "foo"], + "B": ["a", "a", "a", "a"], + "C": [1, 2, 1, 2], + "D": [1, 2, 3, 4], + }).set_index(["A", "B", "C"]).sort_index()) assert not df.index.is_unique - expected = ( - DataFrame({"A": ["foo", "foo"], "B": ["a", "a"], "C": [1, 1], "D": [1, 3]}) - .set_index(["A", "B", "C"]) - .sort_index() - ) + expected = (DataFrame({ + "A": ["foo", "foo"], + "B": ["a", "a"], + "C": [1, 1], + "D": [1, 3] + }).set_index(["A", "B", "C"]).sort_index()) result = df.loc[(slice(None), slice(None), 1), :] assert not result.index.is_unique tm.assert_frame_equal(result, expected) @@ -252,11 +230,12 @@ def test_multiindex_slicers_datetimelike(self): import datetime dates = [ - datetime.datetime(2012, 1, 1, 12, 12, 12) + datetime.timedelta(days=i) - for i in range(6) + datetime.datetime(2012, 1, 1, 12, 12, 12) + + datetime.timedelta(days=i) for i in range(6) ] freq = [1, 2] - index = MultiIndex.from_product([dates, freq], names=["date", "frequency"]) + index = MultiIndex.from_product([dates, freq], + names=["date", "frequency"]) df = DataFrame( np.arange(6 * 2 * 4, dtype="int64").reshape(-1, 4), @@ -267,78 +246,65 @@ def test_multiindex_slicers_datetimelike(self): # multi-axis slicing idx = pd.IndexSlice expected = df.iloc[[0, 2, 4], [0, 1]] - result = df.loc[ - ( - slice( - Timestamp("2012-01-01 12:12:12"), Timestamp("2012-01-03 12:12:12") - ), - slice(1, 1), - ), - slice("A", "B"), - ] + result = df.loc[( + slice(Timestamp("2012-01-01 12:12:12" + ), Timestamp("2012-01-03 12:12:12")), + slice(1, 1), + ), + slice("A", "B"), ] tm.assert_frame_equal(result, expected) - result = df.loc[ - ( - idx[ - Timestamp("2012-01-01 12:12:12") : Timestamp("2012-01-03 12:12:12") - ], - idx[1:1], - ), - slice("A", "B"), - ] + result = df.loc[( + idx[Timestamp("2012-01-01 12:12:12" + ):Timestamp("2012-01-03 12:12:12")], + idx[1:1], + ), + slice("A", "B"), ] tm.assert_frame_equal(result, expected) - result = df.loc[ - ( - slice( - Timestamp("2012-01-01 12:12:12"), Timestamp("2012-01-03 12:12:12") - ), - 1, - ), - slice("A", "B"), - ] + result = df.loc[( + slice(Timestamp("2012-01-01 12:12:12" + ), Timestamp("2012-01-03 12:12:12")), + 1, + ), + slice("A", "B"), ] tm.assert_frame_equal(result, expected) # with strings - result = df.loc[ - (slice("2012-01-01 12:12:12", "2012-01-03 12:12:12"), slice(1, 1)), - slice("A", "B"), - ] + result = df.loc[(slice("2012-01-01 12:12:12", "2012-01-03 12:12:12"), + slice(1, 1)), + slice("A", "B"), ] tm.assert_frame_equal(result, expected) - result = df.loc[ - (idx["2012-01-01 12:12:12":"2012-01-03 12:12:12"], 1), idx["A", "B"] - ] + result = df.loc[(idx["2012-01-01 
12:12:12":"2012-01-03 12:12:12"], 1), + idx["A", "B"]] tm.assert_frame_equal(result, expected) def test_multiindex_slicers_edges(self): # GH 8132 # various edge cases - df = DataFrame( - { - "A": ["A0"] * 5 + ["A1"] * 5 + ["A2"] * 5, - "B": ["B0", "B0", "B1", "B1", "B2"] * 3, - "DATE": [ - "2013-06-11", - "2013-07-02", - "2013-07-09", - "2013-07-30", - "2013-08-06", - "2013-06-11", - "2013-07-02", - "2013-07-09", - "2013-07-30", - "2013-08-06", - "2013-09-03", - "2013-10-01", - "2013-07-09", - "2013-08-06", - "2013-09-03", - ], - "VALUES": [22, 35, 14, 9, 4, 40, 18, 4, 2, 5, 1, 2, 3, 4, 2], - } - ) + df = DataFrame({ + "A": ["A0"] * 5 + ["A1"] * 5 + ["A2"] * 5, + "B": ["B0", "B0", "B1", "B1", "B2"] * 3, + "DATE": [ + "2013-06-11", + "2013-07-02", + "2013-07-09", + "2013-07-30", + "2013-08-06", + "2013-06-11", + "2013-07-02", + "2013-07-09", + "2013-07-30", + "2013-08-06", + "2013-09-03", + "2013-10-01", + "2013-07-09", + "2013-08-06", + "2013-09-03", + ], + "VALUES": [22, 35, 14, 9, 4, 40, 18, 4, 2, 5, 1, 2, 3, 4, 2], + }) df["DATE"] = pd.to_datetime(df["DATE"]) df1 = df.set_index(["A", "B", "DATE"]) @@ -360,7 +326,8 @@ def test_multiindex_slicers_edges(self): tm.assert_frame_equal(result, expected) # A4 - Get all values between 2013-07-02 and 2013-07-09 - result = df1.loc[(slice(None), slice(None), slice("20130702", "20130709")), :] + result = df1.loc[(slice(None), slice(None), + slice("20130702", "20130709")), :] expected = df1.iloc[[1, 2, 6, 7, 12]] tm.assert_frame_equal(result, expected) @@ -376,13 +343,15 @@ def test_multiindex_slicers_edges(self): tm.assert_frame_equal(result, expected) # B3 - Get all values from B1 to B2 and up to 2013-08-06 - result = df1.loc[(slice(None), slice("B1", "B2"), slice("2013-08-06")), :] + result = df1.loc[(slice(None), slice("B1", "B2"), + slice("2013-08-06")), :] expected = df1.iloc[[2, 3, 4, 7, 8, 9, 12, 13]] tm.assert_frame_equal(result, expected) # B4 - Same as A4 but the start of the date slice is not a key. 
# shows indexing on a partial selection slice - result = df1.loc[(slice(None), slice(None), slice("20130701", "20130709")), :] + result = df1.loc[(slice(None), slice(None), + slice("20130701", "20130709")), :] expected = df1.iloc[[1, 2, 6, 7, 12]] tm.assert_frame_equal(result, expected) @@ -393,65 +362,52 @@ def test_per_axis_per_level_doc_examples(self): # from indexing.rst / advanced index = MultiIndex.from_product( - [_mklbl("A", 4), _mklbl("B", 2), _mklbl("C", 4), _mklbl("D", 2)] - ) + [_mklbl("A", 4), + _mklbl("B", 2), + _mklbl("C", 4), + _mklbl("D", 2)]) columns = MultiIndex.from_tuples( [("a", "foo"), ("a", "bar"), ("b", "foo"), ("b", "bah")], names=["lvl0", "lvl1"], ) df = DataFrame( np.arange(len(index) * len(columns), dtype="int64").reshape( - (len(index), len(columns)) - ), + (len(index), len(columns))), index=index, columns=columns, ) result = df.loc[(slice("A1", "A3"), slice(None), ["C1", "C3"]), :] - expected = df.loc[ - [ - ( - a, - b, - c, - d, - ) - for a, b, c, d in df.index.values - if a in ("A1", "A2", "A3") and c in ("C1", "C3") - ] - ] + expected = df.loc[[( + a, + b, + c, + d, + ) for a, b, c, d in df.index.values + if a in ("A1", "A2", "A3") and c in ("C1", "C3")]] tm.assert_frame_equal(result, expected) result = df.loc[idx["A1":"A3", :, ["C1", "C3"]], :] tm.assert_frame_equal(result, expected) result = df.loc[(slice(None), slice(None), ["C1", "C3"]), :] - expected = df.loc[ - [ - ( - a, - b, - c, - d, - ) - for a, b, c, d in df.index.values - if c in ("C1", "C3") - ] - ] + expected = df.loc[[( + a, + b, + c, + d, + ) for a, b, c, d in df.index.values if c in ("C1", "C3")]] tm.assert_frame_equal(result, expected) result = df.loc[idx[:, :, ["C1", "C3"]], :] tm.assert_frame_equal(result, expected) # not sorted - msg = ( - "MultiIndex slicing requires the index to be lexsorted: " - r"slicing on levels \[1\], lexsort depth 1" - ) + msg = ("MultiIndex slicing requires the index to be lexsorted: " + r"slicing on levels \[1\], lexsort depth 1") with pytest.raises(UnsortedIndexError, match=msg): df.loc["A1", ("a", slice("foo"))] # GH 16734: not sorted, but no real slicing - tm.assert_frame_equal( - df.loc["A1", (slice(None), "foo")], df.loc["A1"].iloc[:, [0, 2]] - ) + tm.assert_frame_equal(df.loc["A1", (slice(None), "foo")], + df.loc["A1"].iloc[:, [0, 2]]) df = df.sort_index(axis=1) @@ -465,53 +421,39 @@ def test_per_axis_per_level_doc_examples(self): def test_loc_axis_arguments(self): index = MultiIndex.from_product( - [_mklbl("A", 4), _mklbl("B", 2), _mklbl("C", 4), _mklbl("D", 2)] - ) + [_mklbl("A", 4), + _mklbl("B", 2), + _mklbl("C", 4), + _mklbl("D", 2)]) columns = MultiIndex.from_tuples( [("a", "foo"), ("a", "bar"), ("b", "foo"), ("b", "bah")], names=["lvl0", "lvl1"], ) - df = ( - DataFrame( - np.arange(len(index) * len(columns), dtype="int64").reshape( - (len(index), len(columns)) - ), - index=index, - columns=columns, - ) - .sort_index() - .sort_index(axis=1) - ) + df = (DataFrame( + np.arange(len(index) * len(columns), dtype="int64").reshape( + (len(index), len(columns))), + index=index, + columns=columns, + ).sort_index().sort_index(axis=1)) # axis 0 result = df.loc(axis=0)["A1":"A3", :, ["C1", "C3"]] - expected = df.loc[ - [ - ( - a, - b, - c, - d, - ) - for a, b, c, d in df.index.values - if a in ("A1", "A2", "A3") and c in ("C1", "C3") - ] - ] + expected = df.loc[[( + a, + b, + c, + d, + ) for a, b, c, d in df.index.values + if a in ("A1", "A2", "A3") and c in ("C1", "C3")]] tm.assert_frame_equal(result, expected) result = df.loc(axis="index")[:, :, ["C1", 
"C3"]] - expected = df.loc[ - [ - ( - a, - b, - c, - d, - ) - for a, b, c, d in df.index.values - if c in ("C1", "C3") - ] - ] + expected = df.loc[[( + a, + b, + c, + d, + ) for a, b, c, d in df.index.values if c in ("C1", "C3")]] tm.assert_frame_equal(result, expected) # axis 1 @@ -534,7 +476,8 @@ def test_loc_axis_single_level_multi_col_indexing_multiindex_col_df(self): # GH29519 df = DataFrame( np.arange(27).reshape(3, 9), - columns=MultiIndex.from_product([["a1", "a2", "a3"], ["b1", "b2", "b3"]]), + columns=MultiIndex.from_product([["a1", "a2", "a3"], + ["b1", "b2", "b3"]]), ) result = df.loc(axis=1)["a1":"a2"] expected = df.iloc[:, :-3] @@ -546,7 +489,8 @@ def test_loc_axis_single_level_single_col_indexing_multiindex_col_df(self): # GH29519 df = DataFrame( np.arange(27).reshape(3, 9), - columns=MultiIndex.from_product([["a1", "a2", "a3"], ["b1", "b2", "b3"]]), + columns=MultiIndex.from_product([["a1", "a2", "a3"], + ["b1", "b2", "b3"]]), ) result = df.loc(axis=1)["a1"] expected = df.iloc[:, :3] @@ -569,17 +513,17 @@ def test_per_axis_per_level_setitem(self): idx = pd.IndexSlice # test multi-index slicing with per axis and per index controls - index = MultiIndex.from_tuples( - [("A", 1), ("A", 2), ("A", 3), ("B", 1)], names=["one", "two"] - ) + index = MultiIndex.from_tuples([("A", 1), ("A", 2), ("A", 3), + ("B", 1)], + names=["one", "two"]) columns = MultiIndex.from_tuples( [("a", "foo"), ("a", "bar"), ("b", "foo"), ("b", "bah")], names=["lvl0", "lvl1"], ) - df_orig = DataFrame( - np.arange(16, dtype="int64").reshape(4, 4), index=index, columns=columns - ) + df_orig = DataFrame(np.arange(16, dtype="int64").reshape(4, 4), + index=index, + columns=columns) df_orig = df_orig.sort_index(axis=0).sort_index(axis=1) # identity @@ -654,9 +598,9 @@ def test_per_axis_per_level_setitem(self): # setting with a list-like df = df_orig.copy() - df.loc[(slice(None), 1), (slice(None), ["foo"])] = np.array( - [[100, 100], [100, 100]], dtype="int64" - ) + df.loc[(slice(None), 1), + (slice(None), ["foo"])] = np.array([[100, 100], [100, 100]], + dtype="int64") expected = df_orig.copy() expected.iloc[[0, 3], [1, 3]] = 100 tm.assert_frame_equal(df, expected) @@ -666,29 +610,29 @@ def test_per_axis_per_level_setitem(self): msg = "setting an array element with a sequence." 
with pytest.raises(ValueError, match=msg): - df.loc[(slice(None), 1), (slice(None), ["foo"])] = np.array( - [[100], [100, 100]], dtype="int64" - ) + df.loc[(slice(None), 1), + (slice(None), ["foo"])] = np.array([[100], [100, 100]], + dtype="int64") msg = "Must have equal len keys and value when setting with an iterable" with pytest.raises(ValueError, match=msg): - df.loc[(slice(None), 1), (slice(None), ["foo"])] = np.array( - [100, 100, 100, 100], dtype="int64" - ) + df.loc[(slice(None), 1), + (slice(None), ["foo"])] = np.array([100, 100, 100, 100], + dtype="int64") # with an alignable rhs df = df_orig.copy() - df.loc[(slice(None), 1), (slice(None), ["foo"])] = ( - df.loc[(slice(None), 1), (slice(None), ["foo"])] * 5 - ) + df.loc[(slice(None), 1), + (slice(None), ["foo"])] = (df.loc[(slice(None), 1), + (slice(None), ["foo"])] * 5) expected = df_orig.copy() expected.iloc[[0, 3], [1, 3]] = expected.iloc[[0, 3], [1, 3]] * 5 tm.assert_frame_equal(df, expected) df = df_orig.copy() - df.loc[(slice(None), 1), (slice(None), ["foo"])] *= df.loc[ - (slice(None), 1), (slice(None), ["foo"]) - ] + df.loc[(slice(None), 1), + (slice(None), ["foo"])] *= df.loc[(slice(None), 1), + (slice(None), ["foo"])] expected = df_orig.copy() expected.iloc[[0, 3], [1, 3]] *= expected.iloc[[0, 3], [1, 3]] tm.assert_frame_equal(df, expected) @@ -702,30 +646,37 @@ def test_per_axis_per_level_setitem(self): tm.assert_frame_equal(df, expected) def test_multiindex_label_slicing_with_negative_step(self): - ser = Series( - np.arange(20), MultiIndex.from_product([list("abcde"), np.arange(4)]) - ) + ser = Series(np.arange(20), + MultiIndex.from_product([list("abcde"), + np.arange(4)])) SLC = pd.IndexSlice tm.assert_indexing_slices_equivalent(ser, SLC[::-1], SLC[::-1]) tm.assert_indexing_slices_equivalent(ser, SLC["d"::-1], SLC[15::-1]) - tm.assert_indexing_slices_equivalent(ser, SLC[("d",)::-1], SLC[15::-1]) + tm.assert_indexing_slices_equivalent(ser, SLC[("d", )::-1], + SLC[15::-1]) tm.assert_indexing_slices_equivalent(ser, SLC[:"d":-1], SLC[:11:-1]) - tm.assert_indexing_slices_equivalent(ser, SLC[:("d",):-1], SLC[:11:-1]) - - tm.assert_indexing_slices_equivalent(ser, SLC["d":"b":-1], SLC[15:3:-1]) - tm.assert_indexing_slices_equivalent(ser, SLC[("d",):"b":-1], SLC[15:3:-1]) - tm.assert_indexing_slices_equivalent(ser, SLC["d":("b",):-1], SLC[15:3:-1]) - tm.assert_indexing_slices_equivalent(ser, SLC[("d",):("b",):-1], SLC[15:3:-1]) + tm.assert_indexing_slices_equivalent(ser, SLC[:("d", ):-1], + SLC[:11:-1]) + + tm.assert_indexing_slices_equivalent(ser, SLC["d":"b":-1], + SLC[15:3:-1]) + tm.assert_indexing_slices_equivalent(ser, SLC[("d", ):"b":-1], + SLC[15:3:-1]) + tm.assert_indexing_slices_equivalent(ser, SLC["d":("b", ):-1], + SLC[15:3:-1]) + tm.assert_indexing_slices_equivalent(ser, SLC[("d", ):("b", ):-1], + SLC[15:3:-1]) tm.assert_indexing_slices_equivalent(ser, SLC["b":"d":-1], SLC[:0]) - tm.assert_indexing_slices_equivalent(ser, SLC[("c", 2)::-1], SLC[10::-1]) - tm.assert_indexing_slices_equivalent(ser, SLC[:("c", 2):-1], SLC[:9:-1]) - tm.assert_indexing_slices_equivalent( - ser, SLC[("e", 0):("c", 2):-1], SLC[16:9:-1] - ) + tm.assert_indexing_slices_equivalent(ser, SLC[("c", 2)::-1], + SLC[10::-1]) + tm.assert_indexing_slices_equivalent(ser, SLC[:("c", 2):-1], + SLC[:9:-1]) + tm.assert_indexing_slices_equivalent(ser, SLC[("e", 0):("c", 2):-1], + SLC[16:9:-1]) def test_multiindex_slice_first_level(self): # GH 12697 @@ -734,15 +685,18 @@ def test_multiindex_slice_first_level(self): df = DataFrame(list(range(2000)), 
index=idx, columns=["Test"]) df_slice = df.loc[pd.IndexSlice[:, 30:70], :] result = df_slice.loc["a"] - expected = DataFrame(list(range(30, 71)), columns=["Test"], index=range(30, 71)) + expected = DataFrame(list(range(30, 71)), + columns=["Test"], + index=range(30, 71)) tm.assert_frame_equal(result, expected) result = df_slice.loc["d"] - expected = DataFrame( - list(range(1530, 1571)), columns=["Test"], index=range(30, 71) - ) + expected = DataFrame(list(range(1530, 1571)), + columns=["Test"], + index=range(30, 71)) tm.assert_frame_equal(result, expected) - def test_int_series_slicing(self, multiindex_year_month_day_dataframe_random_data): + def test_int_series_slicing( + self, multiindex_year_month_day_dataframe_random_data): ymd = multiindex_year_month_day_dataframe_random_data s = ymd["A"] result = s[5:] @@ -763,7 +717,7 @@ def test_loc_slice_negative_stepsize(self): mi = MultiIndex.from_product([["a", "b"], [0, 1]]) df = DataFrame([[1, 2], [3, 4], [5, 6], [7, 8]], index=mi) result = df.loc[("a", slice(None, None, -1)), :] - expected = DataFrame( - [[3, 4], [1, 2]], index=MultiIndex.from_tuples([("a", 1), ("a", 0)]) - ) + expected = DataFrame([[3, 4], [1, 2]], + index=MultiIndex.from_tuples([("a", 1), + ("a", 0)])) tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/io/json/test_pandas.py b/pandas/tests/io/json/test_pandas.py index bf8d612d8a1b2..3b3cb539b44cb 100644 --- a/pandas/tests/io/json/test_pandas.py +++ b/pandas/tests/io/json/test_pandas.py @@ -35,17 +35,19 @@ def assert_json_roundtrip_equal(result, expected, orient): @pytest.mark.filterwarnings( - "ignore:an integer is required (got type float)*:DeprecationWarning" -) -@pytest.mark.filterwarnings("ignore:the 'numpy' keyword is deprecated:FutureWarning") + "ignore:an integer is required (got type float)*:DeprecationWarning") +@pytest.mark.filterwarnings( + "ignore:the 'numpy' keyword is deprecated:FutureWarning") class TestPandasContainer: + @pytest.fixture def categorical_frame(self): _seriesd = tm.getSeriesData() _cat_frame = DataFrame(_seriesd) - cat = ["bah"] * 5 + ["bar"] * 5 + ["baz"] * 5 + ["foo"] * (len(_cat_frame) - 15) + cat = ["bah"] * 5 + ["bar"] * 5 + ["baz"] * 5 + ["foo"] * ( + len(_cat_frame) - 15) _cat_frame.index = pd.CategoricalIndex(cat, name="E") _cat_frame["E"] = list(reversed(cat)) _cat_frame["sort"] = np.arange(len(_cat_frame), dtype="int64") @@ -82,7 +84,9 @@ def test_frame_double_encoded_labels(self, orient): @pytest.mark.parametrize("orient", ["split", "records", "values"]) def test_frame_non_unique_index(self, orient): - df = DataFrame([["a", "b"], ["c", "d"]], index=[1, 1], columns=["x", "y"]) + df = DataFrame([["a", "b"], ["c", "d"]], + index=[1, 1], + columns=["x", "y"]) result = read_json(df.to_json(orient=orient), orient=orient) expected = df.copy() @@ -90,7 +94,9 @@ def test_frame_non_unique_index(self, orient): @pytest.mark.parametrize("orient", ["index", "columns"]) def test_frame_non_unique_index_raises(self, orient): - df = DataFrame([["a", "b"], ["c", "d"]], index=[1, 1], columns=["x", "y"]) + df = DataFrame([["a", "b"], ["c", "d"]], + index=[1, 1], + columns=["x", "y"]) msg = f"DataFrame index must be unique for orient='{orient}'" with pytest.raises(ValueError, match=msg): df.to_json(orient=orient) @@ -108,9 +114,9 @@ def test_frame_non_unique_index_raises(self, orient): def test_frame_non_unique_columns(self, orient, data): df = DataFrame(data, index=[1, 2], columns=["x", "x"]) - result = read_json( - df.to_json(orient=orient), orient=orient, convert_dates=["x"] - ) + 
result = read_json(df.to_json(orient=orient), + orient=orient, + convert_dates=["x"]) if orient == "values": expected = DataFrame(data) if expected.iloc[:, 0].dtype == "datetime64[ns]": @@ -118,7 +124,8 @@ def test_frame_non_unique_columns(self, orient, data): # in milliseconds; these are internally stored in nanosecond, # so divide to get where we need # TODO: a to_epoch method would also solve; see GH 14772 - expected.iloc[:, 0] = expected.iloc[:, 0].view(np.int64) // 1000000 + expected.iloc[:, 0] = expected.iloc[:, 0].view( + np.int64) // 1000000 elif orient == "split": expected = df @@ -126,7 +133,9 @@ def test_frame_non_unique_columns(self, orient, data): @pytest.mark.parametrize("orient", ["index", "columns", "records"]) def test_frame_non_unique_columns_raises(self, orient): - df = DataFrame([["a", "b"], ["c", "d"]], index=[1, 2], columns=["x", "x"]) + df = DataFrame([["a", "b"], ["c", "d"]], + index=[1, 2], + columns=["x", "x"]) msg = f"DataFrame columns must be unique for orient='{orient}'" with pytest.raises(ValueError, match=msg): @@ -138,11 +147,14 @@ def test_frame_default_orient(self, float_frame): @pytest.mark.parametrize("dtype", [False, float]) @pytest.mark.parametrize("convert_axes", [True, False]) @pytest.mark.parametrize("numpy", [True, False]) - def test_roundtrip_simple(self, orient, convert_axes, numpy, dtype, float_frame): + def test_roundtrip_simple(self, orient, convert_axes, numpy, dtype, + float_frame): data = float_frame.to_json(orient=orient) - result = read_json( - data, orient=orient, convert_axes=convert_axes, numpy=numpy, dtype=dtype - ) + result = read_json(data, + orient=orient, + convert_axes=convert_axes, + numpy=numpy, + dtype=dtype) expected = float_frame @@ -151,18 +163,17 @@ def test_roundtrip_simple(self, orient, convert_axes, numpy, dtype, float_frame) @pytest.mark.parametrize("dtype", [False, np.int64]) @pytest.mark.parametrize("convert_axes", [True, False]) @pytest.mark.parametrize("numpy", [True, False]) - def test_roundtrip_intframe(self, orient, convert_axes, numpy, dtype, int_frame): + def test_roundtrip_intframe(self, orient, convert_axes, numpy, dtype, + int_frame): data = int_frame.to_json(orient=orient) - result = read_json( - data, orient=orient, convert_axes=convert_axes, numpy=numpy, dtype=dtype - ) + result = read_json(data, + orient=orient, + convert_axes=convert_axes, + numpy=numpy, + dtype=dtype) expected = int_frame - if ( - numpy - and (not IS64 or is_platform_windows()) - and not dtype - and orient != "split" - ): + if (numpy and (not IS64 or is_platform_windows()) and not dtype + and orient != "split"): # TODO: see what is causing roundtrip dtype loss expected = expected.astype(np.int32) @@ -171,7 +182,8 @@ def test_roundtrip_intframe(self, orient, convert_axes, numpy, dtype, int_frame) @pytest.mark.parametrize("dtype", [None, np.float64, int, "U3"]) @pytest.mark.parametrize("convert_axes", [True, False]) @pytest.mark.parametrize("numpy", [True, False]) - def test_roundtrip_str_axes(self, request, orient, convert_axes, numpy, dtype): + def test_roundtrip_str_axes(self, request, orient, convert_axes, numpy, + dtype): df = DataFrame( np.zeros((200, 4)), columns=[str(i) for i in range(4)], @@ -182,13 +194,14 @@ def test_roundtrip_str_axes(self, request, orient, convert_axes, numpy, dtype): # TODO: do we even need to support U3 dtypes? 
if numpy and dtype == "U3" and orient != "split": request.node.add_marker( - pytest.mark.xfail(reason="Can't decode directly to array") - ) + pytest.mark.xfail(reason="Can't decode directly to array")) data = df.to_json(orient=orient) - result = read_json( - data, orient=orient, convert_axes=convert_axes, numpy=numpy, dtype=dtype - ) + result = read_json(data, + orient=orient, + convert_axes=convert_axes, + numpy=numpy, + dtype=dtype) expected = df.copy() if not dtype: @@ -211,27 +224,29 @@ def test_roundtrip_str_axes(self, request, orient, convert_axes, numpy, dtype): @pytest.mark.parametrize("convert_axes", [True, False]) @pytest.mark.parametrize("numpy", [True, False]) - def test_roundtrip_categorical( - self, request, orient, categorical_frame, convert_axes, numpy - ): + def test_roundtrip_categorical(self, request, orient, categorical_frame, + convert_axes, numpy): # TODO: create a better frame to test with and improve coverage if orient in ("index", "columns"): request.node.add_marker( pytest.mark.xfail( reason=f"Can't have duplicate index values for orient '{orient}')" - ) - ) + )) data = categorical_frame.to_json(orient=orient) if numpy and orient in ("records", "values"): request.node.add_marker( - pytest.mark.xfail(reason=f"Orient {orient} is broken with numpy=True") - ) + pytest.mark.xfail( + reason=f"Orient {orient} is broken with numpy=True")) - result = read_json(data, orient=orient, convert_axes=convert_axes, numpy=numpy) + result = read_json(data, + orient=orient, + convert_axes=convert_axes, + numpy=numpy) expected = categorical_frame.copy() - expected.index = expected.index.astype(str) # Categorical not preserved + expected.index = expected.index.astype( + str) # Categorical not preserved expected.index.name = None # index names aren't preserved in JSON if not numpy and orient == "index": @@ -244,7 +259,10 @@ def test_roundtrip_categorical( def test_roundtrip_empty(self, orient, convert_axes, numpy): empty_frame = DataFrame() data = empty_frame.to_json(orient=orient) - result = read_json(data, orient=orient, convert_axes=convert_axes, numpy=numpy) + result = read_json(data, + orient=orient, + convert_axes=convert_axes, + numpy=numpy) expected = empty_frame.copy() # TODO: both conditions below are probably bugs @@ -258,10 +276,14 @@ def test_roundtrip_empty(self, orient, convert_axes, numpy): @pytest.mark.parametrize("convert_axes", [True, False]) @pytest.mark.parametrize("numpy", [True, False]) - def test_roundtrip_timestamp(self, orient, convert_axes, numpy, datetime_frame): + def test_roundtrip_timestamp(self, orient, convert_axes, numpy, + datetime_frame): # TODO: improve coverage with date_format parameter data = datetime_frame.to_json(orient=orient) - result = read_json(data, orient=orient, convert_axes=convert_axes, numpy=numpy) + result = read_json(data, + orient=orient, + convert_axes=convert_axes, + numpy=numpy) expected = datetime_frame.copy() if not convert_axes: # one off for ts handling @@ -279,8 +301,7 @@ def test_roundtrip_timestamp(self, orient, convert_axes, numpy, datetime_frame): def test_roundtrip_mixed(self, request, orient, convert_axes, numpy): if numpy and orient != "split": request.node.add_marker( - pytest.mark.xfail(reason="Can't decode directly to array") - ) + pytest.mark.xfail(reason="Can't decode directly to array")) index = pd.Index(["a", "b", "c", "d", "e"]) values = { @@ -293,10 +314,14 @@ def test_roundtrip_mixed(self, request, orient, convert_axes, numpy): df = DataFrame(data=values, index=index) data = df.to_json(orient=orient) - 
result = read_json(data, orient=orient, convert_axes=convert_axes, numpy=numpy) + result = read_json(data, + orient=orient, + convert_axes=convert_axes, + numpy=numpy) expected = df.copy() - expected = expected.assign(**expected.select_dtypes("number").astype(np.int64)) + expected = expected.assign( + **expected.select_dtypes("number").astype(np.int64)) if not numpy and orient == "index": expected = expected.sort_index() @@ -312,11 +337,9 @@ def test_roundtrip_mixed(self, request, orient, convert_axes, numpy): '{"columns":["A","B"],' '"index":["2","3"],' '"data":[[1.0,"1"],[2.0,"2"],[null,"3"]]}', - "|".join( - [ - r"Length of values \(3\) does not match length of index \(2\)", - ] - ), + "|".join([ + r"Length of values \(3\) does not match length of index \(2\)", + ]), "split", ), # too many columns @@ -344,7 +367,8 @@ def test_frame_from_json_bad_data_raises(self, data, msg, orient): @pytest.mark.parametrize("dtype", [True, False]) @pytest.mark.parametrize("convert_axes", [True, False]) @pytest.mark.parametrize("numpy", [True, False]) - def test_frame_from_json_missing_data(self, orient, convert_axes, numpy, dtype): + def test_frame_from_json_missing_data(self, orient, convert_axes, numpy, + dtype): num_df = DataFrame([[1, 2], [4, 5, 6]]) result = read_json( num_df.to_json(orient=orient), @@ -382,7 +406,8 @@ def test_frame_infinity(self, inf, dtype): result = read_json(df.to_json(), dtype=dtype) assert np.isnan(result.iloc[0, 2]) - @pytest.mark.skipif(not IS64, reason="not compliant on 32-bit, xref #15865") + @pytest.mark.skipif(not IS64, + reason="not compliant on 32-bit, xref #15865") @pytest.mark.parametrize( "value,precision,expected_val", [ @@ -394,7 +419,8 @@ def test_frame_infinity(self, inf, dtype): (0.99999999999999944, 15, 1.0), ], ) - def test_frame_to_json_float_precision(self, value, precision, expected_val): + def test_frame_to_json_float_precision(self, value, precision, + expected_val): df = DataFrame([{"a_float": value}]) encoded = df.to_json(double_precision=precision) assert encoded == f'{{"a_float":{{"0":{expected_val}}}}}' @@ -408,9 +434,9 @@ def test_frame_to_json_except(self): def test_frame_empty(self): df = DataFrame(columns=["jim", "joe"]) assert not df._is_mixed_type - tm.assert_frame_equal( - read_json(df.to_json(), dtype=dict(df.dtypes)), df, check_index_type=False - ) + tm.assert_frame_equal(read_json(df.to_json(), dtype=dict(df.dtypes)), + df, + check_index_type=False) # GH 7445 result = DataFrame({"test": []}, index=[]).to_json(orient="columns") expected = '{"test":{}}' @@ -421,9 +447,9 @@ def test_frame_empty_mixedtype(self): df = DataFrame(columns=["jim", "joe"]) df["joe"] = df["joe"].astype("i8") assert df._is_mixed_type - tm.assert_frame_equal( - read_json(df.to_json(), dtype=dict(df.dtypes)), df, check_index_type=False - ) + tm.assert_frame_equal(read_json(df.to_json(), dtype=dict(df.dtypes)), + df, + check_index_type=False) def test_frame_mixedtype_orient(self): # GH10289 vals = [ @@ -433,9 +459,9 @@ def test_frame_mixedtype_orient(self): # GH10289 [40, 4, "qux", 0.4, 0.04], ] - df = DataFrame( - vals, index=list("abcd"), columns=["1st", "2nd", "3rd", "4th", "5th"] - ) + df = DataFrame(vals, + index=list("abcd"), + columns=["1st", "2nd", "3rd", "4th", "5th"]) assert df._is_mixed_type right = df.copy() @@ -571,7 +597,8 @@ def test_blocks_compat_GH9037(self): # JSON deserialisation always creates unicode strings df_mixed.columns = df_mixed.columns.astype("unicode") - df_roundtrip = read_json(df_mixed.to_json(orient="split"), orient="split") + 
df_roundtrip = read_json(df_mixed.to_json(orient="split"), + orient="split") tm.assert_frame_equal( df_mixed, df_roundtrip, @@ -585,6 +612,7 @@ def test_frame_nonprintable_bytes(self): # GH14256: failing column caused segfaults, if it is not the last one class BinaryThing: + def __init__(self, hexed): self.hexed = hexed self.binary = bytes.fromhex(hexed) @@ -614,10 +642,8 @@ def __str__(self) -> str: result = df_nonprintable.to_json(default_handler=str) expected = f'{{"A":{{"0":"{hexed}"}}}}' assert result == expected - assert ( - df_mixed.to_json(default_handler=str) - == f'{{"A":{{"0":"{hexed}"}},"B":{{"0":1}}}}' - ) + assert (df_mixed.to_json( + default_handler=str) == f'{{"A":{{"0":"{hexed}"}},"B":{{"0":1}}}}') def test_label_overflow(self): # GH14256: buffer length not checked when writing label @@ -633,11 +659,11 @@ def test_series_non_unique_index(self): s.to_json(orient="index") tm.assert_series_equal( - s, read_json(s.to_json(orient="split"), orient="split", typ="series") - ) - unserialized = read_json( - s.to_json(orient="records"), orient="records", typ="series" - ) + s, + read_json(s.to_json(orient="split"), orient="split", typ="series")) + unserialized = read_json(s.to_json(orient="records"), + orient="records", + typ="series") tm.assert_numpy_array_equal(s.values, unserialized.values) def test_series_default_orient(self, string_series): @@ -658,9 +684,14 @@ def test_series_roundtrip_simple(self, orient, numpy, string_series): @pytest.mark.parametrize("dtype", [False, None]) @pytest.mark.parametrize("numpy", [True, False]) - def test_series_roundtrip_object(self, orient, numpy, dtype, object_series): + def test_series_roundtrip_object(self, orient, numpy, dtype, + object_series): data = object_series.to_json(orient=orient) - result = read_json(data, typ="series", orient=orient, numpy=numpy, dtype=dtype) + result = read_json(data, + typ="series", + orient=orient, + numpy=numpy, + dtype=dtype) expected = object_series if orient in ("values", "records"): @@ -748,7 +779,9 @@ def test_frame_from_json_precise_float(self): def test_typ(self): - s = Series(range(6), index=["a", "b", "c", "d", "e", "f"], dtype="int64") + s = Series(range(6), + index=["a", "b", "c", "d", "e", "f"], + dtype="int64") result = read_json(s.to_json(), typ=None) tm.assert_series_equal(result, s) @@ -809,7 +842,8 @@ def test_convert_dates(self, datetime_series, datetime_frame): @pytest.mark.parametrize("date_format", ["epoch", "iso"]) @pytest.mark.parametrize("as_object", [True, False]) - @pytest.mark.parametrize("date_typ", [datetime.date, datetime.datetime, Timestamp]) + @pytest.mark.parametrize("date_typ", + [datetime.date, datetime.datetime, Timestamp]) def test_date_index_and_values(self, date_format, as_object, date_typ): data = [date_typ(year=2020, month=1, day=1), pd.NaT] if as_object: @@ -847,9 +881,8 @@ def test_convert_dates_infer(self, infer_word): from pandas.io.json import dumps data = [{"id": 1, infer_word: 1036713600000}, {"id": 2}] - expected = DataFrame( - [[1, Timestamp("2002-11-08")], [2, pd.NaT]], columns=["id", infer_word] - ) + expected = DataFrame([[1, Timestamp("2002-11-08")], [2, pd.NaT]], + columns=["id", infer_word]) result = read_json(dumps(data))[["id", infer_word]] tm.assert_frame_equal(result, expected) @@ -910,7 +943,8 @@ def test_date_format_series(self, date, date_unit, datetime_series): tm.assert_series_equal(result, expected) def test_date_format_series_raises(self, datetime_series): - ts = Series(Timestamp("20130101 20:43:42.123"), index=datetime_series.index) + ts = 
Series(Timestamp("20130101 20:43:42.123"), + index=datetime_series.index) msg = "Invalid value 'foo' for option 'date_unit'" with pytest.raises(ValueError, match=msg): ts.to_json(date_format="iso", date_unit="foo") @@ -991,7 +1025,8 @@ def test_round_trip_exception_(self, datapath): df = pd.read_csv(path) s = df.to_json() result = read_json(s) - tm.assert_frame_equal(result.reindex(index=df.index, columns=df.columns), df) + tm.assert_frame_equal( + result.reindex(index=df.index, columns=df.columns), df) @pytest.mark.network @tm.network( @@ -1020,22 +1055,22 @@ def test_timedelta(self): result = read_json(s.to_json(), typ="series").apply(converter) tm.assert_series_equal(result, s) - s = Series([timedelta(23), timedelta(seconds=5)], index=pd.Index([0, 1])) + s = Series([timedelta(23), timedelta(seconds=5)], + index=pd.Index([0, 1])) assert s.dtype == "timedelta64[ns]" result = read_json(s.to_json(), typ="series").apply(converter) tm.assert_series_equal(result, s) frame = DataFrame([timedelta(23), timedelta(seconds=5)]) assert frame[0].dtype == "timedelta64[ns]" - tm.assert_frame_equal(frame, read_json(frame.to_json()).apply(converter)) + tm.assert_frame_equal(frame, + read_json(frame.to_json()).apply(converter)) - frame = DataFrame( - { - "a": [timedelta(days=23), timedelta(seconds=5)], - "b": [1, 2], - "c": pd.date_range(start="20130101", periods=2), - } - ) + frame = DataFrame({ + "a": [timedelta(days=23), timedelta(seconds=5)], + "b": [1, 2], + "c": pd.date_range(start="20130101", periods=2), + }) result = read_json(frame.to_json(date_unit="ns")) result["a"] = pd.to_timedelta(result.a, unit="ns") @@ -1043,11 +1078,13 @@ def test_timedelta(self): tm.assert_frame_equal(frame, result) def test_mixed_timedelta_datetime(self): - frame = DataFrame({"a": [timedelta(23), Timestamp("20130101")]}, dtype=object) + frame = DataFrame( + {"a": [timedelta(23), Timestamp("20130101")]}, dtype=object) - expected = DataFrame( - {"a": [pd.Timedelta(frame.a[0]).value, Timestamp(frame.a[1]).value]} - ) + expected = DataFrame({ + "a": [pd.Timedelta(frame.a[0]).value, + Timestamp(frame.a[1]).value] + }) result = read_json(frame.to_json(date_unit="ns"), dtype={"a": "int64"}) tm.assert_frame_equal(result, expected, check_index_type=False) @@ -1086,33 +1123,37 @@ def test_default_handler_indirect(self): def default(obj): if isinstance(obj, complex): - return [("mathjs", "Complex"), ("re", obj.real), ("im", obj.imag)] + return [("mathjs", "Complex"), ("re", obj.real), + ("im", obj.imag)] return str(obj) df_list = [ 9, DataFrame( - {"a": [1, "STR", complex(4, -5)], "b": [float("nan"), None, "N/A"]}, + { + "a": [1, "STR", complex(4, -5)], + "b": [float("nan"), None, "N/A"] + }, columns=["a", "b"], ), ] - expected = ( - '[9,[[1,null],["STR",null],[[["mathjs","Complex"],' - '["re",4.0],["im",-5.0]],"N\\/A"]]]' - ) - assert dumps(df_list, default_handler=default, orient="values") == expected + expected = ('[9,[[1,null],["STR",null],[[["mathjs","Complex"],' + '["re",4.0],["im",-5.0]],"N\\/A"]]]') + assert dumps(df_list, default_handler=default, + orient="values") == expected def test_default_handler_numpy_unsupported_dtype(self): # GH12554 to_json raises 'Unhandled numpy dtype 15' df = DataFrame( - {"a": [1, 2.3, complex(4, -5)], "b": [float("nan"), None, complex(1.2, 0)]}, + { + "a": [1, 2.3, complex(4, -5)], + "b": [float("nan"), None, complex(1.2, 0)] + }, columns=["a", "b"], ) - expected = ( - '[["(1+0j)","(nan+0j)"],' - '["(2.3+0j)","(nan+0j)"],' - '["(4-5j)","(1.2+0j)"]]' - ) + expected = ('[["(1+0j)","(nan+0j)"],' 
+ '["(2.3+0j)","(nan+0j)"],' + '["(4-5j)","(1.2+0j)"]]') assert df.to_json(default_handler=str, orient="values") == expected def test_default_handler_raises(self): @@ -1122,13 +1163,13 @@ def my_handler_raises(obj): raise TypeError(msg) with pytest.raises(TypeError, match=msg): - DataFrame({"a": [1, 2, object()]}).to_json( - default_handler=my_handler_raises - ) + DataFrame({ + "a": [1, 2, object()] + }).to_json(default_handler=my_handler_raises) with pytest.raises(TypeError, match=msg): - DataFrame({"a": [1, 2, complex(4, -5)]}).to_json( - default_handler=my_handler_raises - ) + DataFrame({ + "a": [1, 2, complex(4, -5)] + }).to_json(default_handler=my_handler_raises) def test_categorical(self): # GH4377 df.to_json segfaults with non-ndarray blocks @@ -1148,7 +1189,10 @@ def test_datetime_tz(self): tz_range = pd.date_range("20130101", periods=3, tz="US/Eastern") tz_naive = tz_range.tz_convert("utc").tz_localize(None) - df = DataFrame({"A": tz_range, "B": pd.date_range("20130101", periods=3)}) + df = DataFrame({ + "A": tz_range, + "B": pd.date_range("20130101", periods=3) + }) df_naive = df.copy() df_naive["A"] = tz_naive @@ -1204,11 +1248,9 @@ def test_tz_range_is_utc(self, tz_range): from pandas.io.json import dumps exp = '["2013-01-01T05:00:00.000Z","2013-01-02T05:00:00.000Z"]' - dfexp = ( - '{"DT":{' - '"0":"2013-01-01T05:00:00.000Z",' - '"1":"2013-01-02T05:00:00.000Z"}}' - ) + dfexp = ('{"DT":{' + '"0":"2013-01-01T05:00:00.000Z",' + '"1":"2013-01-02T05:00:00.000Z"}}') assert dumps(tz_range, iso_dates=True) == exp dti = DatetimeIndex(tz_range) @@ -1228,9 +1270,9 @@ def test_read_inline_jsonl(self): def test_read_s3_jsonl(self, s3_resource, s3so): # GH17200 - result = read_json( - "s3n://pandas-test/items.jsonl", lines=True, storage_options=s3so - ) + result = read_json("s3n://pandas-test/items.jsonl", + lines=True, + storage_options=s3so) expected = DataFrame([[1, 2], [1, 2]], columns=["a", "b"]) tm.assert_frame_equal(result, expected) @@ -1251,13 +1293,15 @@ def test_read_jsonl_unicode_chars(self): json = '{"a": "foo”", "b": "bar"}\n{"a": "foo", "b": "bar"}\n' json = StringIO(json) result = read_json(json, lines=True) - expected = DataFrame([["foo\u201d", "bar"], ["foo", "bar"]], columns=["a", "b"]) + expected = DataFrame([["foo\u201d", "bar"], ["foo", "bar"]], + columns=["a", "b"]) tm.assert_frame_equal(result, expected) # simulate string json = '{"a": "foo”", "b": "bar"}\n{"a": "foo", "b": "bar"}\n' result = read_json(json, lines=True) - expected = DataFrame([["foo\u201d", "bar"], ["foo", "bar"]], columns=["a", "b"]) + expected = DataFrame([["foo\u201d", "bar"], ["foo", "bar"]], + columns=["a", "b"]) tm.assert_frame_equal(result, expected) @pytest.mark.parametrize("bigNum", [sys.maxsize + 1, -(sys.maxsize + 2)]) @@ -1313,14 +1357,16 @@ def test_to_jsonl(self): tm.assert_frame_equal(read_json(result, lines=True), df) # GH15096: escaped characters in columns and data - df = DataFrame([["foo\\", "bar"], ['foo"', "bar"]], columns=["a\\", "b"]) + df = DataFrame([["foo\\", "bar"], ['foo"', "bar"]], + columns=["a\\", "b"]) result = df.to_json(orient="records", lines=True) expected = '{"a\\\\":"foo\\\\","b":"bar"}\n{"a\\\\":"foo\\"","b":"bar"}\n' assert result == expected tm.assert_frame_equal(read_json(result, lines=True), df) # TODO: there is a near-identical test for pytables; can we share? 
- @pytest.mark.xfail(reason="GH#13774 encoding kwarg not supported", raises=TypeError) + @pytest.mark.xfail(reason="GH#13774 encoding kwarg not supported", + raises=TypeError) def test_latin_encoding(self): # GH 13774 values = [ @@ -1335,10 +1381,9 @@ def test_latin_encoding(self): [b"A\xf8\xfc", np.nan, b"", b"b", b"c"], ] - values = [ - [x.decode("latin-1") if isinstance(x, bytes) else x for x in y] - for y in values - ] + values = [[ + x.decode("latin-1") if isinstance(x, bytes) else x for x in y + ] for y in values] examples = [] for dtype in ["category", object]: @@ -1365,8 +1410,8 @@ def test_data_frame_size_after_to_json(self): assert size_before == size_after @pytest.mark.parametrize( - "index", [None, [1, 2], [1.0, 2.0], ["a", "b"], ["1", "2"], ["1.", "2."]] - ) + "index", + [None, [1, 2], [1.0, 2.0], ["a", "b"], ["1", "2"], ["1.", "2."]]) @pytest.mark.parametrize("columns", [["a", "b"], ["1", "2"], ["1.", "2."]]) def test_from_json_to_json_table_index_and_columns(self, index, columns): # GH25433 GH25435 @@ -1382,21 +1427,24 @@ def test_from_json_to_json_table_dtypes(self): result = read_json(dfjson, orient="table") tm.assert_frame_equal(result, expected) - @pytest.mark.parametrize("orient", ["split", "records", "index", "columns"]) + @pytest.mark.parametrize("orient", + ["split", "records", "index", "columns"]) def test_to_json_from_json_columns_dtypes(self, orient): # GH21892 GH33205 - expected = DataFrame.from_dict( - { - "Integer": Series([1, 2, 3], dtype="int64"), - "Float": Series([None, 2.0, 3.0], dtype="float64"), - "Object": Series([None, "", "c"], dtype="object"), - "Bool": Series([True, False, True], dtype="bool"), - "Category": Series(["a", "b", None], dtype="category"), - "Datetime": Series( - ["2020-01-01", None, "2020-01-03"], dtype="datetime64[ns]" - ), - } - ) + expected = DataFrame.from_dict({ + "Integer": + Series([1, 2, 3], dtype="int64"), + "Float": + Series([None, 2.0, 3.0], dtype="float64"), + "Object": + Series([None, "", "c"], dtype="object"), + "Bool": + Series([True, False, True], dtype="bool"), + "Category": + Series(["a", "b", None], dtype="category"), + "Datetime": + Series(["2020-01-01", None, "2020-01-03"], dtype="datetime64[ns]"), + }) dfjson = expected.to_json(orient=orient) result = read_json( dfjson, @@ -1423,7 +1471,9 @@ def test_read_json_table_dtype_raises(self, dtype): def test_read_json_table_convert_axes_raises(self): # GH25433 GH25435 - df = DataFrame([[1, 2], [3, 4]], index=[1.0, 2.0], columns=["1.", "2."]) + df = DataFrame([[1, 2], [3, 4]], + index=[1.0, 2.0], + columns=["1.", "2."]) dfjson = df.to_json(orient="table") msg = "cannot pass both convert_axes and orient='table'" with pytest.raises(ValueError, match=msg): @@ -1434,26 +1484,45 @@ def test_read_json_table_convert_axes_raises(self): [ ( DataFrame([[1, 2], [4, 5]], columns=["a", "b"]), - {"columns": ["a", "b"], "data": [[1, 2], [4, 5]]}, + { + "columns": ["a", "b"], + "data": [[1, 2], [4, 5]] + }, ), ( - DataFrame([[1, 2], [4, 5]], columns=["a", "b"]).rename_axis("foo"), - {"columns": ["a", "b"], "data": [[1, 2], [4, 5]]}, + DataFrame([[1, 2], [4, 5]], columns=["a", "b" + ]).rename_axis("foo"), + { + "columns": ["a", "b"], + "data": [[1, 2], [4, 5]] + }, ), ( - DataFrame( - [[1, 2], [4, 5]], columns=["a", "b"], index=[["a", "b"], ["c", "d"]] - ), - {"columns": ["a", "b"], "data": [[1, 2], [4, 5]]}, + DataFrame([[1, 2], [4, 5]], + columns=["a", "b"], + index=[["a", "b"], ["c", "d"]]), + { + "columns": ["a", "b"], + "data": [[1, 2], [4, 5]] + }, ), - (Series([1, 2, 3], name="A"), 
{"name": "A", "data": [1, 2, 3]}), + (Series([1, 2, 3], name="A"), { + "name": "A", + "data": [1, 2, 3] + }), ( Series([1, 2, 3], name="A").rename_axis("foo"), - {"name": "A", "data": [1, 2, 3]}, + { + "name": "A", + "data": [1, 2, 3] + }, ), ( Series([1, 2], name="A", index=[["a", "b"], ["c", "d"]]), - {"name": "A", "data": [1, 2]}, + { + "name": "A", + "data": [1, 2] + }, ), ], ) @@ -1470,12 +1539,11 @@ def test_index_false_to_json_split(self, data, expected): "data", [ (DataFrame([[1, 2], [4, 5]], columns=["a", "b"])), - (DataFrame([[1, 2], [4, 5]], columns=["a", "b"]).rename_axis("foo")), - ( - DataFrame( - [[1, 2], [4, 5]], columns=["a", "b"], index=[["a", "b"], ["c", "d"]] - ) - ), + (DataFrame([[1, 2], [4, 5]], columns=["a", "b" + ]).rename_axis("foo")), + (DataFrame([[1, 2], [4, 5]], + columns=["a", "b"], + index=[["a", "b"], ["c", "d"]])), (Series([1, 2, 3], name="A")), (Series([1, 2, 3], name="A").rename_axis("foo")), (Series([1, 2], name="A", index=[["a", "b"], ["c", "d"]])), @@ -1495,7 +1563,8 @@ def test_index_false_to_json_table(self, data): assert result == expected - @pytest.mark.parametrize("orient", ["records", "index", "columns", "values"]) + @pytest.mark.parametrize("orient", + ["records", "index", "columns", "values"]) def test_index_false_error_to_json(self, orient): # GH 17394 # Testing error message from to_json with index=False @@ -1518,10 +1587,12 @@ def test_index_false_from_json_to_json(self, orient, index): def test_read_timezone_information(self): # GH 25546 - result = read_json( - '{"2019-01-01T11:00:00.000Z":88}', typ="series", orient="index" - ) - expected = Series([88], index=DatetimeIndex(["2019-01-01 11:00:00"], tz="UTC")) + result = read_json('{"2019-01-01T11:00:00.000Z":88}', + typ="series", + orient="index") + expected = Series([88], + index=DatetimeIndex(["2019-01-01 11:00:00"], + tz="UTC")) tm.assert_series_equal(result, expected) @pytest.mark.parametrize( @@ -1539,9 +1610,8 @@ def test_read_json_with_url_value(self, url): expected = DataFrame({"url": [url]}) tm.assert_frame_equal(result, expected) - @pytest.mark.parametrize( - "date_format,key", [("epoch", 86400000), ("iso", "P1DT0H0M0S")] - ) + @pytest.mark.parametrize("date_format,key", [("epoch", 86400000), + ("iso", "P1DT0H0M0S")]) def test_timedelta_as_label(self, date_format, key): df = DataFrame([[1]], columns=[pd.Timedelta("1D")]) expected = f'{{"{key}":{{"0":1}}}}' @@ -1559,15 +1629,13 @@ def test_timedelta_as_label(self, date_format, key): "split", "", marks=pytest.mark.xfail( - reason="Produces JSON but not in a consistent manner" - ), + reason="Produces JSON but not in a consistent manner"), ), pytest.param( "table", "", marks=pytest.mark.xfail( - reason="Produces JSON but not in a consistent manner" - ), + reason="Produces JSON but not in a consistent manner"), ), ], ) @@ -1729,8 +1797,7 @@ def test_emca_262_nan_inf_support(self): data = '["a", NaN, "NaN", Infinity, "Infinity", -Infinity, "-Infinity"]' result = read_json(data) expected = DataFrame( - ["a", np.nan, "NaN", np.inf, "Infinity", -np.inf, "-Infinity"] - ) + ["a", np.nan, "NaN", np.inf, "Infinity", -np.inf, "-Infinity"]) tm.assert_frame_equal(result, expected) def test_deprecate_numpy_argument_read_json(self): @@ -1742,20 +1809,25 @@ def test_deprecate_numpy_argument_read_json(self): def test_frame_int_overflow(self): # GH 30320 - encoded_json = json.dumps([{"col": "31900441201190696999"}, {"col": "Text"}]) + encoded_json = json.dumps([{ + "col": "31900441201190696999" + }, { + "col": "Text" + }]) expected = DataFrame({"col": 
["31900441201190696999", "Text"]}) result = read_json(encoded_json) tm.assert_frame_equal(result, expected) @pytest.mark.parametrize( "dataframe,expected", - [ - ( - DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]}), - '{"(0, \'x\')":1,"(0, \'y\')":"a","(1, \'x\')":2,' - '"(1, \'y\')":"b","(2, \'x\')":3,"(2, \'y\')":"c"}', - ) - ], + [( + DataFrame({ + "x": [1, 2, 3], + "y": ["a", "b", "c"] + }), + '{"(0, \'x\')":1,"(0, \'y\')":"a","(1, \'x\')":2,' + '"(1, \'y\')":"b","(2, \'x\')":3,"(2, \'y\')":"c"}', + )], ) def test_json_multiindex(self, dataframe, expected): series = dataframe.stack() @@ -1769,12 +1841,12 @@ def test_to_s3(self, s3_resource, s3so): # GH 28375 mock_bucket_name, target_file = "pandas-test", "test.json" df = DataFrame({"x": [1, 2, 3], "y": [2, 4, 6]}) - df.to_json(f"s3://{mock_bucket_name}/{target_file}", storage_options=s3so) + df.to_json(f"s3://{mock_bucket_name}/{target_file}", + storage_options=s3so) timeout = 5 while True: - if target_file in ( - obj.key for obj in s3_resource.Bucket("pandas-test").objects.all() - ): + if target_file in (obj.key for obj in s3_resource.Bucket( + "pandas-test").objects.all()): break time.sleep(0.1) timeout -= 0.1 @@ -1803,20 +1875,20 @@ def test_to_json_multiindex_escape(self): columns=["foo", "bar"], ).stack() result = df.to_json() - expected = ( - "{\"(Timestamp('2017-01-20 00:00:00'), 'foo')\":true," - "\"(Timestamp('2017-01-20 00:00:00'), 'bar')\":true," - "\"(Timestamp('2017-01-21 00:00:00'), 'foo')\":true," - "\"(Timestamp('2017-01-21 00:00:00'), 'bar')\":true," - "\"(Timestamp('2017-01-22 00:00:00'), 'foo')\":true," - "\"(Timestamp('2017-01-22 00:00:00'), 'bar')\":true," - "\"(Timestamp('2017-01-23 00:00:00'), 'foo')\":true," - "\"(Timestamp('2017-01-23 00:00:00'), 'bar')\":true}" - ) + expected = ("{\"(Timestamp('2017-01-20 00:00:00'), 'foo')\":true," + "\"(Timestamp('2017-01-20 00:00:00'), 'bar')\":true," + "\"(Timestamp('2017-01-21 00:00:00'), 'foo')\":true," + "\"(Timestamp('2017-01-21 00:00:00'), 'bar')\":true," + "\"(Timestamp('2017-01-22 00:00:00'), 'foo')\":true," + "\"(Timestamp('2017-01-22 00:00:00'), 'bar')\":true," + "\"(Timestamp('2017-01-23 00:00:00'), 'foo')\":true," + "\"(Timestamp('2017-01-23 00:00:00'), 'bar')\":true}") assert result == expected def test_to_json_series_of_objects(self): + class _TestObject: + def __init__(self, a, b, _c, d): self.a = a self.b = b @@ -1834,13 +1906,21 @@ def e(self): "data,expected", [ ( - Series({0: -6 + 8j, 1: 0 + 1j, 2: 9 - 5j}), + Series({ + 0: -6 + 8j, + 1: 0 + 1j, + 2: 9 - 5j + }), '{"0":{"imag":8.0,"real":-6.0},' '"1":{"imag":1.0,"real":0.0},' '"2":{"imag":-5.0,"real":9.0}}', ), ( - Series({0: -9.39 + 0.66j, 1: 3.95 + 9.32j, 2: 4.03 - 0.17j}), + Series({ + 0: -9.39 + 0.66j, + 1: 3.95 + 9.32j, + 2: 4.03 - 0.17j + }), '{"0":{"imag":0.66,"real":-9.39},' '"1":{"imag":9.32,"real":3.95},' '"2":{"imag":-0.17,"real":4.03}}', @@ -1853,9 +1933,8 @@ def e(self): '"1":{"imag":-10.0,"real":0.0}}}', ), ( - DataFrame( - [[-0.28 + 0.34j, -1.08 - 0.39j], [0.41 - 0.34j, -0.78 - 1.35j]] - ), + DataFrame([[-0.28 + 0.34j, -1.08 - 0.39j], + [0.41 - 0.34j, -0.78 - 1.35j]]), '{"0":{"0":{"imag":0.34,"real":-0.28},' '"1":{"imag":-0.34,"real":0.41}},' '"1":{"0":{"imag":-0.39,"real":-1.08},' diff --git a/pandas/tests/io/pytables/common.py b/pandas/tests/io/pytables/common.py index f86cd594bb9ed..0cdbb395fb9b6 100644 --- a/pandas/tests/io/pytables/common.py +++ b/pandas/tests/io/pytables/common.py @@ -36,7 +36,11 @@ def create_tempfile(path): # contextmanager to ensure the file cleanup 
@contextmanager -def ensure_clean_store(path, mode="a", complevel=None, complib=None, fletcher32=False): +def ensure_clean_store(path, + mode="a", + complevel=None, + complib=None, + fletcher32=False): try: @@ -44,9 +48,11 @@ def ensure_clean_store(path, mode="a", complevel=None, complib=None, fletcher32= if not len(os.path.dirname(path)): path = create_tempfile(path) - store = HDFStore( - path, mode=mode, complevel=complevel, complib=complib, fletcher32=False - ) + store = HDFStore(path, + mode=mode, + complevel=complevel, + complib=complib, + fletcher32=False) yield store finally: safe_close(store) diff --git a/pandas/tests/util/test_assert_series_equal.py b/pandas/tests/util/test_assert_series_equal.py index b0c4371bffb66..b8c20e10b0e73 100644 --- a/pandas/tests/util/test_assert_series_equal.py +++ b/pandas/tests/util/test_assert_series_equal.py @@ -93,9 +93,15 @@ def test_series_not_equal_value_mismatch(data1, data2): @pytest.mark.parametrize( "kwargs", [ - {"dtype": "float64"}, # dtype mismatch - {"index": [1, 2, 4]}, # index mismatch - {"name": "foo"}, # name mismatch + { + "dtype": "float64" + }, # dtype mismatch + { + "index": [1, 2, 4] + }, # index mismatch + { + "name": "foo" + }, # name mismatch ], ) def test_series_not_equal_metadata_mismatch(kwargs): @@ -114,9 +120,7 @@ def test_less_precise(data1, data2, dtype, decimals): s1 = Series([data1], dtype=dtype) s2 = Series([data2], dtype=dtype) - if decimals in (5, 10) or ( - decimals >= 3 and abs(data1 - data2) >= 0.0005 - ): + if decimals in (5, 10) or (decimals >= 3 and abs(data1 - data2) >= 0.0005): if is_extension_array_dtype(dtype): msg = "ExtensionArray are different" else: @@ -139,11 +143,19 @@ def test_less_precise(data1, data2, dtype, decimals): # MultiIndex ( DataFrame.from_records( - {"a": [1, 2], "b": [2.1, 1.5], "c": ["l1", "l2"]}, index=["a", "b"] - ).c, + { + "a": [1, 2], + "b": [2.1, 1.5], + "c": ["l1", "l2"] + }, + index=["a", "b"]).c, DataFrame.from_records( - {"a": [1.0, 2.0], "b": [2.1, 1.5], "c": ["l1", "l2"]}, index=["a", "b"] - ).c, + { + "a": [1.0, 2.0], + "b": [2.1, 1.5], + "c": ["l1", "l2"] + }, + index=["a", "b"]).c, "MultiIndex level \\[0\\] are different", ), ], @@ -272,6 +284,7 @@ def test_assert_series_equal_interval_dtype_mismatch(): def test_series_equal_series_type(): + class MySeries(Series): pass diff --git a/pandas/tseries/frequencies.py b/pandas/tseries/frequencies.py index a3c3befa88205..b8ef1531bdf5d 100644 --- a/pandas/tseries/frequencies.py +++ b/pandas/tseries/frequencies.py @@ -165,15 +165,10 @@ def infer_freq(index, warn: bool = True) -> str | None: if isinstance(index, ABCSeries): values = index._values - if not ( - is_datetime64_dtype(values) - or is_timedelta64_dtype(values) - or values.dtype == object - ): - raise TypeError( - "cannot infer freq from a non-convertible dtype " - f"on a Series of {index.dtype}" - ) + if not (is_datetime64_dtype(values) or is_timedelta64_dtype(values) + or values.dtype == object): + raise TypeError("cannot infer freq from a non-convertible dtype " + f"on a Series of {index.dtype}") index = values inferer: _FrequencyInferer @@ -181,10 +176,8 @@ def infer_freq(index, warn: bool = True) -> str | None: if not hasattr(index, "dtype"): pass elif is_period_dtype(index.dtype): - raise TypeError( - "PeriodIndex given. Check the `freq` attribute " - "instead of using infer_freq." - ) + raise TypeError("PeriodIndex given. 
Check the `freq` attribute " + "instead of using infer_freq.") elif is_timedelta64_dtype(index.dtype): # Allow TimedeltaIndex and TimedeltaArray inferer = _TimedeltaFrequencyInferer(index, warn=warn) @@ -218,8 +211,7 @@ def __init__(self, index, warn: bool = True): if hasattr(index, "tz"): if index.tz is not None: self.i8values = tzconversion.tz_convert_from_utc( - self.i8values, index.tz - ) + self.i8values, index.tz) if warn is not True: warnings.warn( @@ -233,9 +225,8 @@ def __init__(self, index, warn: bool = True): if len(index) < 3: raise ValueError("Need at least 3 dates to infer frequency") - self.is_monotonic = ( - self.index._is_monotonic_increasing or self.index._is_monotonic_decreasing - ) + self.is_monotonic = (self.index._is_monotonic_increasing + or self.index._is_monotonic_decreasing) @cache_readonly def deltas(self) -> npt.NDArray[np.int64]: @@ -409,11 +400,8 @@ def _is_business_daily(self) -> bool: weekdays = np.mod(first_weekday + np.cumsum(shifts), 7) return bool( - np.all( - ((weekdays == 0) & (shifts == 3)) - | ((weekdays > 0) & (weekdays <= 4) & (shifts == 1)) - ) - ) + np.all(((weekdays == 0) & (shifts == 3)) + | ((weekdays > 0) & (weekdays <= 4) & (shifts == 1)))) def _get_wom_rule(self) -> str | None: # FIXME: dont leave commented-out @@ -440,6 +428,7 @@ def _get_wom_rule(self) -> str | None: class _TimedeltaFrequencyInferer(_FrequencyInferer): + def _infer_daily_rule(self): if self.is_unique: return self._get_daily_rule() @@ -486,9 +475,8 @@ def is_subperiod(source, target) -> bool: if _is_annual(target): if _is_quarterly(source): - return _quarter_months_conform( - get_rule_month(source), get_rule_month(target) - ) + return _quarter_months_conform(get_rule_month(source), + get_rule_month(target)) return source in {"D", "C", "B", "M", "H", "T", "S", "L", "U", "N"} elif _is_quarterly(target): return source in {"D", "C", "B", "M", "H", "T", "S", "L", "U", "N"} diff --git a/pandas/tseries/holiday.py b/pandas/tseries/holiday.py index 50b06ce143163..01d60cbdf3e12 100644 --- a/pandas/tseries/holiday.py +++ b/pandas/tseries/holiday.py @@ -219,10 +219,10 @@ class from pandas.tseries.offsets self.month = month self.day = day self.offset = offset - self.start_date = ( - Timestamp(start_date) if start_date is not None else start_date - ) - self.end_date = Timestamp(end_date) if end_date is not None else end_date + self.start_date = (Timestamp(start_date) + if start_date is not None else start_date) + self.end_date = Timestamp( + end_date) if end_date is not None else end_date self.observance = observance assert days_of_week is None or type(days_of_week) == tuple self.days_of_week = days_of_week @@ -270,21 +270,18 @@ def dates(self, start_date, end_date, return_name=False): dates = self._reference_dates(start_date, end_date) holiday_dates = self._apply_rule(dates) if self.days_of_week is not None: - holiday_dates = holiday_dates[ - np.in1d(holiday_dates.dayofweek, self.days_of_week) - ] + holiday_dates = holiday_dates[np.in1d(holiday_dates.dayofweek, + self.days_of_week)] if self.start_date is not None: filter_start_date = max( - self.start_date.tz_localize(filter_start_date.tz), filter_start_date - ) + self.start_date.tz_localize(filter_start_date.tz), + filter_start_date) if self.end_date is not None: filter_end_date = min( - self.end_date.tz_localize(filter_end_date.tz), filter_end_date - ) - holiday_dates = holiday_dates[ - (holiday_dates >= filter_start_date) & (holiday_dates <= filter_end_date) - ] + self.end_date.tz_localize(filter_end_date.tz), filter_end_date) + 
holiday_dates = holiday_dates[(holiday_dates >= filter_start_date) + & (holiday_dates <= filter_end_date)] if return_name: return Series(self.name, index=holiday_dates) return holiday_dates @@ -306,12 +303,10 @@ def _reference_dates(self, start_date, end_date): year_offset = DateOffset(years=1) reference_start_date = Timestamp( - datetime(start_date.year - 1, self.month, self.day) - ) + datetime(start_date.year - 1, self.month, self.day)) reference_end_date = Timestamp( - datetime(end_date.year + 1, self.month, self.day) - ) + datetime(end_date.year + 1, self.month, self.day)) # Don't process unnecessary holidays dates = date_range( start=reference_start_date, @@ -377,6 +372,7 @@ def get_calendar(name): class HolidayCalendarMetaClass(type): + def __new__(cls, clsname, bases, attrs): calendar_class = super().__new__(cls, clsname, bases, attrs) register(calendar_class) @@ -452,7 +448,8 @@ def holidays(self, start=None, end=None, return_name=False): # If we don't have a cache or the dates are outside the prior cache, we # get them again - if self._cache is None or start < self._cache[0] or end > self._cache[1]: + if self._cache is None or start < self._cache[0] or end > self._cache[ + 1]: pre_holidays = [ rule.dates(start, end, return_name=True) for rule in self.rules ] @@ -525,16 +522,22 @@ def merge(self, other, inplace=False): return holidays -USMemorialDay = Holiday( - "Memorial Day", month=5, day=31, offset=DateOffset(weekday=MO(-1)) -) -USLaborDay = Holiday("Labor Day", month=9, day=1, offset=DateOffset(weekday=MO(1))) -USColumbusDay = Holiday( - "Columbus Day", month=10, day=1, offset=DateOffset(weekday=MO(2)) -) -USThanksgivingDay = Holiday( - "Thanksgiving Day", month=11, day=1, offset=DateOffset(weekday=TH(4)) -) +USMemorialDay = Holiday("Memorial Day", + month=5, + day=31, + offset=DateOffset(weekday=MO(-1))) +USLaborDay = Holiday("Labor Day", + month=9, + day=1, + offset=DateOffset(weekday=MO(1))) +USColumbusDay = Holiday("Columbus Day", + month=10, + day=1, + offset=DateOffset(weekday=MO(2))) +USThanksgivingDay = Holiday("Thanksgiving Day", + month=11, + day=1, + offset=DateOffset(weekday=TH(4))) USMartinLutherKingJr = Holiday( "Birthday of Martin Luther King, Jr.", start_date=datetime(1986, 1, 1), @@ -542,12 +545,16 @@ def merge(self, other, inplace=False): day=1, offset=DateOffset(weekday=MO(3)), ) -USPresidentsDay = Holiday( - "Washington’s Birthday", month=2, day=1, offset=DateOffset(weekday=MO(3)) -) +USPresidentsDay = Holiday("Washington’s Birthday", + month=2, + day=1, + offset=DateOffset(weekday=MO(3))) GoodFriday = Holiday("Good Friday", month=1, day=1, offset=[Easter(), Day(-2)]) -EasterMonday = Holiday("Easter Monday", month=1, day=1, offset=[Easter(), Day(1)]) +EasterMonday = Holiday("Easter Monday", + month=1, + day=1, + offset=[Easter(), Day(1)]) class USFederalHolidayCalendar(AbstractHolidayCalendar): @@ -569,7 +576,8 @@ class USFederalHolidayCalendar(AbstractHolidayCalendar): start_date="2021-06-18", observance=nearest_workday, ), - Holiday("Independence Day", month=7, day=4, observance=nearest_workday), + Holiday("Independence Day", month=7, day=4, + observance=nearest_workday), USLaborDay, USColumbusDay, Holiday("Veterans Day", month=11, day=11, observance=nearest_workday), @@ -578,7 +586,10 @@ class USFederalHolidayCalendar(AbstractHolidayCalendar): ] -def HolidayCalendarFactory(name, base, other, base_class=AbstractHolidayCalendar): +def HolidayCalendarFactory(name, + base, + other, + base_class=AbstractHolidayCalendar): rules = 
AbstractHolidayCalendar.merge_class(base, other) - calendar_class = type(name, (base_class,), {"rules": rules, "name": name}) + calendar_class = type(name, (base_class, ), {"rules": rules, "name": name}) return calendar_class diff --git a/setup.py b/setup.py index 471ded06502af..72c78bbd90b28 100755 --- a/setup.py +++ b/setup.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 - """ Parts of this file were taken from the pyzmq project (https://github.com/zeromq/pyzmq) which have been permitted for use under the @@ -47,15 +46,16 @@ def is_platform_mac(): ) from Cython.Build import cythonize - _CYTHON_INSTALLED = parse_version(_CYTHON_VERSION) >= parse_version(min_cython_ver) + _CYTHON_INSTALLED = parse_version(_CYTHON_VERSION) >= parse_version( + min_cython_ver) except ImportError: _CYTHON_VERSION = None _CYTHON_INSTALLED = False cythonize = lambda x, *args, **kwargs: x # dummy func - _pxi_dep_template = { - "algos": ["_libs/algos_common_helper.pxi.in", "_libs/algos_take_helper.pxi.in"], + "algos": + ["_libs/algos_common_helper.pxi.in", "_libs/algos_take_helper.pxi.in"], "hashtable": [ "_libs/hashtable_class_helper.pxi.in", "_libs/hashtable_func_helper.pxi.in", @@ -75,6 +75,7 @@ def is_platform_mac(): class build_ext(_build_ext): + @classmethod def render_templates(cls, pxifiles): for pxifile in pxifiles: @@ -82,10 +83,8 @@ def render_templates(cls, pxifiles): assert pxifile.endswith(".pxi.in") outfile = pxifile[:-3] - if ( - os.path.exists(outfile) - and os.stat(pxifile).st_mtime < os.stat(outfile).st_mtime - ): + if (os.path.exists(outfile) + and os.stat(pxifile).st_mtime < os.stat(outfile).st_mtime): # if .pxi.in is not updated, no need to output .pxi continue @@ -143,14 +142,14 @@ def initialize_options(self): continue if os.path.splitext(f)[-1] in ( - ".pyc", - ".so", - ".o", - ".pyo", - ".pyd", - ".c", - ".cpp", - ".orig", + ".pyc", + ".so", + ".o", + ".pyo", + ".pyd", + ".c", + ".cpp", + ".orig", ): self._clean_me.append(filepath) for d in dirs: @@ -249,8 +248,7 @@ def run(self): sourcefile = pyxfile[:-3] + extension msg = ( f"{extension}-source file '{sourcefile}' not found.\n" - "Run 'setup.py cython' before sdist." - ) + "Run 'setup.py cython' before sdist.") assert os.path.isfile(sourcefile), msg sdist_class.run(self) @@ -269,8 +267,7 @@ def check_cython_extensions(self, extensions): f"""Cython-generated file '{src}' not found. Cython is required to compile pandas from a development branch. Please install Cython or download a release package of pandas. 
- """ - ) + """) def build_extensions(self): self.check_cython_extensions(self.extensions) @@ -320,13 +317,11 @@ def run(self): if debugging_symbols_requested: sys.argv.remove("--with-debugging-symbols") - if sys.byteorder == "big": endian_macro = [("__BIG_ENDIAN__", "1")] else: endian_macro = [("__LITTLE_ENDIAN__", "1")] - extra_compile_args = [] extra_link_args = [] if is_platform_windows(): @@ -349,15 +344,12 @@ def run(self): if is_platform_mac(): if "MACOSX_DEPLOYMENT_TARGET" not in os.environ: current_system = platform.mac_ver()[0] - python_target = get_config_vars().get( - "MACOSX_DEPLOYMENT_TARGET", current_system - ) + python_target = get_config_vars().get("MACOSX_DEPLOYMENT_TARGET", + current_system) target_macos_version = "10.9" parsed_macos_version = parse_version(target_macos_version) - if ( - parse_version(str(python_target)) < parsed_macos_version - and parse_version(current_system) >= parsed_macos_version - ): + if (parse_version(str(python_target)) < parsed_macos_version + and parse_version(current_system) >= parsed_macos_version): os.environ["MACOSX_DEPLOYMENT_TARGET"] = target_macos_version if sys.version_info[:2] == (3, 8): # GH 33239 @@ -389,10 +381,10 @@ def run(self): # cython+numpy version mismatches. macros.append(("NPY_NO_DEPRECATED_API", "0")) - # ---------------------------------------------------------------------- # Specification of Dependencies + # TODO(cython#4518): Need to check to see if e.g. `linetrace` has changed and # possibly re-compile. def maybe_cythonize(extensions, *args, **kwargs): @@ -411,8 +403,7 @@ def maybe_cythonize(extensions, *args, **kwargs): if _CYTHON_VERSION: raise RuntimeError( f"Cannot cythonize with old Cython version ({_CYTHON_VERSION} " - f"installed, needs {min_cython_ver})" - ) + f"installed, needs {min_cython_ver})") raise RuntimeError("Cannot cythonize without Cython installed.") # reuse any parallel arguments provided for compilation to cythonize @@ -444,40 +435,61 @@ def srcpath(name=None, suffix=".pyx", subdir="src"): "include": klib_include, "depends": _pxi_dep["algos"], }, - "_libs.arrays": {"pyxfile": "_libs/arrays"}, - "_libs.groupby": {"pyxfile": "_libs/groupby"}, - "_libs.hashing": {"pyxfile": "_libs/hashing", "depends": []}, + "_libs.arrays": { + "pyxfile": "_libs/arrays" + }, + "_libs.groupby": { + "pyxfile": "_libs/groupby" + }, + "_libs.hashing": { + "pyxfile": "_libs/hashing", + "depends": [] + }, "_libs.hashtable": { - "pyxfile": "_libs/hashtable", - "include": klib_include, - "depends": ( - ["pandas/_libs/src/klib/khash_python.h", "pandas/_libs/src/klib/khash.h"] - + _pxi_dep["hashtable"] - ), + "pyxfile": + "_libs/hashtable", + "include": + klib_include, + "depends": ([ + "pandas/_libs/src/klib/khash_python.h", + "pandas/_libs/src/klib/khash.h" + ] + _pxi_dep["hashtable"]), }, "_libs.index": { "pyxfile": "_libs/index", "include": klib_include, "depends": _pxi_dep["index"], }, - "_libs.indexing": {"pyxfile": "_libs/indexing"}, - "_libs.internals": {"pyxfile": "_libs/internals"}, + "_libs.indexing": { + "pyxfile": "_libs/indexing" + }, + "_libs.internals": { + "pyxfile": "_libs/internals" + }, "_libs.interval": { "pyxfile": "_libs/interval", "include": klib_include, "depends": _pxi_dep["interval"], }, - "_libs.join": {"pyxfile": "_libs/join", "include": klib_include}, + "_libs.join": { + "pyxfile": "_libs/join", + "include": klib_include + }, "_libs.lib": { "pyxfile": "_libs/lib", "depends": lib_depends + tseries_depends, "include": klib_include, # due to tokenizer import "sources": 
["pandas/_libs/src/parser/tokenizer.c"], }, - "_libs.missing": {"pyxfile": "_libs/missing", "depends": tseries_depends}, + "_libs.missing": { + "pyxfile": "_libs/missing", + "depends": tseries_depends + }, "_libs.parsers": { - "pyxfile": "_libs/parsers", - "include": klib_include + ["pandas/_libs/src"], + "pyxfile": + "_libs/parsers", + "include": + klib_include + ["pandas/_libs/src"], "depends": [ "pandas/_libs/src/parser/tokenizer.h", "pandas/_libs/src/parser/io.h", @@ -487,17 +499,42 @@ def srcpath(name=None, suffix=".pyx", subdir="src"): "pandas/_libs/src/parser/io.c", ], }, - "_libs.reduction": {"pyxfile": "_libs/reduction"}, - "_libs.ops": {"pyxfile": "_libs/ops"}, - "_libs.ops_dispatch": {"pyxfile": "_libs/ops_dispatch"}, - "_libs.properties": {"pyxfile": "_libs/properties"}, - "_libs.reshape": {"pyxfile": "_libs/reshape", "depends": []}, - "_libs.sparse": {"pyxfile": "_libs/sparse", "depends": _pxi_dep["sparse"]}, - "_libs.tslib": {"pyxfile": "_libs/tslib", "depends": tseries_depends}, - "_libs.tslibs.base": {"pyxfile": "_libs/tslibs/base"}, - "_libs.tslibs.ccalendar": {"pyxfile": "_libs/tslibs/ccalendar"}, - "_libs.tslibs.ctime": {"pyxfile": "_libs/tslibs/ctime"}, - "_libs.tslibs.dtypes": {"pyxfile": "_libs/tslibs/dtypes"}, + "_libs.reduction": { + "pyxfile": "_libs/reduction" + }, + "_libs.ops": { + "pyxfile": "_libs/ops" + }, + "_libs.ops_dispatch": { + "pyxfile": "_libs/ops_dispatch" + }, + "_libs.properties": { + "pyxfile": "_libs/properties" + }, + "_libs.reshape": { + "pyxfile": "_libs/reshape", + "depends": [] + }, + "_libs.sparse": { + "pyxfile": "_libs/sparse", + "depends": _pxi_dep["sparse"] + }, + "_libs.tslib": { + "pyxfile": "_libs/tslib", + "depends": tseries_depends + }, + "_libs.tslibs.base": { + "pyxfile": "_libs/tslibs/base" + }, + "_libs.tslibs.ccalendar": { + "pyxfile": "_libs/tslibs/ccalendar" + }, + "_libs.tslibs.ctime": { + "pyxfile": "_libs/tslibs/ctime" + }, + "_libs.tslibs.dtypes": { + "pyxfile": "_libs/tslibs/dtypes" + }, "_libs.tslibs.conversion": { "pyxfile": "_libs/tslibs/conversion", "depends": tseries_depends, @@ -507,10 +544,14 @@ def srcpath(name=None, suffix=".pyx", subdir="src"): "pyxfile": "_libs/tslibs/fields", "depends": tseries_depends, }, - "_libs.tslibs.nattype": {"pyxfile": "_libs/tslibs/nattype"}, + "_libs.tslibs.nattype": { + "pyxfile": "_libs/tslibs/nattype" + }, "_libs.tslibs.np_datetime": { - "pyxfile": "_libs/tslibs/np_datetime", - "depends": tseries_depends, + "pyxfile": + "_libs/tslibs/np_datetime", + "depends": + tseries_depends, "sources": [ "pandas/_libs/tslibs/src/datetime/np_datetime.c", "pandas/_libs/tslibs/src/datetime/np_datetime_strings.c", @@ -543,22 +584,34 @@ def srcpath(name=None, suffix=".pyx", subdir="src"): "pyxfile": "_libs/tslibs/timestamps", "depends": tseries_depends, }, - "_libs.tslibs.timezones": {"pyxfile": "_libs/tslibs/timezones"}, + "_libs.tslibs.timezones": { + "pyxfile": "_libs/tslibs/timezones" + }, "_libs.tslibs.tzconversion": { "pyxfile": "_libs/tslibs/tzconversion", "depends": tseries_depends, }, - "_libs.tslibs.vectorized": {"pyxfile": "_libs/tslibs/vectorized"}, - "_libs.testing": {"pyxfile": "_libs/testing"}, + "_libs.tslibs.vectorized": { + "pyxfile": "_libs/tslibs/vectorized" + }, + "_libs.testing": { + "pyxfile": "_libs/testing" + }, "_libs.window.aggregations": { "pyxfile": "_libs/window/aggregations", "language": "c++", "suffix": ".cpp", "depends": ["pandas/_libs/src/skiplist.h"], }, - "_libs.window.indexers": {"pyxfile": "_libs/window/indexers"}, - "_libs.writers": {"pyxfile": 
"_libs/writers"}, - "io.sas._sas": {"pyxfile": "io/sas/sas"}, + "_libs.window.indexers": { + "pyxfile": "_libs/window/indexers" + }, + "_libs.writers": { + "pyxfile": "_libs/writers" + }, + "io.sas._sas": { + "pyxfile": "io/sas/sas" + }, } extensions = [] @@ -575,11 +628,9 @@ def srcpath(name=None, suffix=".pyx", subdir="src"): undef_macros = [] - if ( - sys.platform == "zos" - and data.get("language") == "c++" - and os.path.basename(os.environ.get("CXX", "/bin/xlc++")) in ("xlc", "xlc++") - ): + if (sys.platform == "zos" + and data.get("language") == "c++" and os.path.basename( + os.environ.get("CXX", "/bin/xlc++")) in ("xlc", "xlc++")): data.get("macros", macros).append(("__s390__", "1")) extra_compile_args.append("-qlanglvl=extended0x:nolibext") undef_macros.append("_POSIX_THREADS") @@ -614,20 +665,17 @@ def srcpath(name=None, suffix=".pyx", subdir="src"): "pandas/_libs/src/ujson/lib/ultrajson.h", "pandas/_libs/src/ujson/python/date_conversions.h", ], - sources=( - [ - "pandas/_libs/src/ujson/python/ujson.c", - "pandas/_libs/src/ujson/python/objToJSON.c", - "pandas/_libs/src/ujson/python/date_conversions.c", - "pandas/_libs/src/ujson/python/JSONtoObj.c", - "pandas/_libs/src/ujson/lib/ultrajsonenc.c", - "pandas/_libs/src/ujson/lib/ultrajsondec.c", - ] - + [ - "pandas/_libs/tslibs/src/datetime/np_datetime.c", - "pandas/_libs/tslibs/src/datetime/np_datetime_strings.c", - ] - ), + sources=([ + "pandas/_libs/src/ujson/python/ujson.c", + "pandas/_libs/src/ujson/python/objToJSON.c", + "pandas/_libs/src/ujson/python/date_conversions.c", + "pandas/_libs/src/ujson/python/JSONtoObj.c", + "pandas/_libs/src/ujson/lib/ultrajsonenc.c", + "pandas/_libs/src/ujson/lib/ultrajsondec.c", + ] + [ + "pandas/_libs/tslibs/src/datetime/np_datetime.c", + "pandas/_libs/tslibs/src/datetime/np_datetime_strings.c", + ]), include_dirs=[ "pandas/_libs/src/ujson/python", "pandas/_libs/src/ujson/lib", @@ -639,17 +687,16 @@ def srcpath(name=None, suffix=".pyx", subdir="src"): define_macros=macros, ) - extensions.append(ujson_ext) # ---------------------------------------------------------------------- - if __name__ == "__main__": # Freeze to support parallel compilation when using spawn instead of fork multiprocessing.freeze_support() setup( version=versioneer.get_version(), - ext_modules=maybe_cythonize(extensions, compiler_directives=directives), + ext_modules=maybe_cythonize(extensions, + compiler_directives=directives), cmdclass=cmdclass, )