|
8 | 8 | from typing import Literal, Optional, Union |
9 | 9 |
|
10 | 10 | import numpy as np |
| 11 | +import pandas as pd |
11 | 12 | from xarray import DataArray, Dataset |
12 | 13 |
|
13 | 14 | from . import plot_actions |
14 | 15 |
|
15 | 16 | # from sisl.viz.nodes.processors.grid import get_isos |
16 | 17 |
|
17 | 18 |
|
| 19 | +def _correct_pandas_string_none(dic: dict[str, DataArray]): |
| 20 | + """Correct string arguments in xarray arrays to avoid casting to nan |
| 21 | +
|
| 22 | + This is a migration notice from pandas >= 3. |
| 23 | + They change their str implementation to a custom (new default) dtype. |
| 24 | + And the result is that `None` will be parsed to `np.nan` as opposed |
| 25 | + to the old `None -> None` parsing. |
| 26 | + """ |
| 27 | + |
| 28 | + # xarray uses pandas for casting types, and pandas decided to change |
| 29 | + # the default behavior and parse [None, ""], to [nan, ""]. Passing |
| 30 | + # NaN to plotting functions when they are expecting None breaks things. |
| 31 | + # Here we make sure that we parse all NaN to None |
| 32 | + # Pandas docs on this: |
| 33 | + # https://pandas.pydata.org/docs/user_guide/migration-3-strings.html |
| 34 | + old_infer_string = pd.options.future.infer_string |
| 35 | + pd.options.future.infer_string = False |
| 36 | + |
| 37 | + for key, value in dic.items(): |
| 38 | + value = dic[key] |
| 39 | + try: |
| 40 | + dic[key] = value.where(value.notnull(), other=None) |
| 41 | + except: |
| 42 | + pass |
| 43 | + |
| 44 | + pd.options.future.infer_string = old_infer_string |
| 45 | + |
| 46 | + |
| 47 | +if int(pd.__version__.split(".")[0]) >= 3: |
| 48 | + |
| 49 | + def _correct_pandas_string_none(dic: dict[str, str]): |
| 50 | + pass |
| 51 | + |
| 52 | + |
18 | 53 | def _process_xarray_data( |
19 | 54 | data: Union[DataArray, Dataset], |
20 | 55 | x: Union[str, None] = None, |
@@ -461,6 +496,8 @@ def _draw_xarray_lines( |
461 | 496 | "border_width", |
462 | 497 | "border_color", |
463 | 498 | ) |
| 499 | + _correct_pandas_string_none(style) |
| 500 | + |
464 | 501 | for key in style_keys: |
465 | 502 | lines_style[key] = style.get(key) |
466 | 503 |
|
|
0 commit comments