|
| 1 | +from importlib import import_module |
| 2 | + |
| 3 | +import pint |
| 4 | +from pint.quantity import Quantity |
| 5 | +from pint.unit import Unit |
| 6 | +# TODO is it possible to import pint-xarray from within xarray if pint is present? |
| 7 | +import xarray as xr |
| 8 | +import numpy as np |
| 9 | +from xarray.core.npcompat import IS_NEP18_ACTIVE |
| 10 | + |
| 11 | + |
| 12 | +if not hasattr(Quantity, "__array_function__"): |
| 13 | + raise ImportError("Imported version of pint does not implement " |
| 14 | + "__array_function__ yet") |
| 15 | + |
| 16 | +if not IS_NEP18_ACTIVE: |
| 17 | + raise ImportError("NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled") |
| 18 | + |
| 19 | +try: |
| 20 | + mpl = import_module("matplotlib") |
| 21 | + has_mpl = True |
| 22 | +except ImportError: |
| 23 | + has_mpl = False |
| 24 | +else: |
| 25 | + # TODO can we do this without initialising a Unit Registry? |
| 26 | + unit_registry = pint.UnitRegistry() |
| 27 | + unit_registry.setup_matplotlib(True) |
| 28 | + |
| 29 | + |
| 30 | +# TODO could/should we overwrite xr.open_dataset and xr.open_mfdataset to make |
| 31 | +# them apply units upon loading??? |
| 32 | +# TODO could even override the decode_cf kwarg? |
| 33 | + |
| 34 | +# TODO docstrings |
| 35 | +# TODO type hints |
| 36 | +# TODO f-strings |
| 37 | + |
| 38 | +def array_attach_units(data, unit, convert_from=None): |
| 39 | + if isinstance(data, Quantity): |
| 40 | + if not convert_from: |
| 41 | + raise ValueError( |
| 42 | + "cannot attach unit {unit} to quantity ({data.units})".format( |
| 43 | + unit=unit, data=data |
| 44 | + ) |
| 45 | + ) |
| 46 | + elif isinstance(convert_from, Unit): |
| 47 | + data = data.magnitude |
| 48 | + elif convert_from is True: # intentionally accept exactly true |
| 49 | + if data.check(unit): |
| 50 | + convert_from = data.units |
| 51 | + data = data.magnitude |
| 52 | + else: |
| 53 | + raise ValueError( |
| 54 | + "cannot convert quantity ({data.units}) to {unit}".format( |
| 55 | + unit=unit, data=data |
| 56 | + ) |
| 57 | + ) |
| 58 | + else: |
| 59 | + raise ValueError( |
| 60 | + "cannot convert from invalid unit {convert_from}".format( |
| 61 | + convert_from=convert_from |
| 62 | + ) |
| 63 | + ) |
| 64 | + |
| 65 | + # to make sure we also encounter the case of "equal if converted" |
| 66 | + if convert_from is not None: |
| 67 | + quantity = (data * convert_from).to( |
| 68 | + unit |
| 69 | + if isinstance(unit, Unit) |
| 70 | + else unit.dimensionless |
| 71 | + ) |
| 72 | + else: |
| 73 | + try: |
| 74 | + quantity = data * unit |
| 75 | + except np.core._exceptions.UFuncTypeError: |
| 76 | + if unit != 1: |
| 77 | + raise |
| 78 | + |
| 79 | + quantity = data |
| 80 | + |
| 81 | + return quantity |
| 82 | + |
| 83 | + |
| 84 | +# TODO Error checking (that data is actually a quantity etc) |
| 85 | + |
| 86 | +# TODO refactor with an apply_to(da, data_method) function? |
| 87 | + |
| 88 | + |
| 89 | +@xr.register_dataarray_accessor("pint") |
| 90 | +class PintDataArrayAccessor: |
| 91 | + def __init__(self, da): |
| 92 | + self.da = da |
| 93 | + |
| 94 | + def quantify(self, units=None, unit_registry=None, decode_cf=False): |
| 95 | + # TODO read and decode units according to CF conventions (see MetPy) |
| 96 | + if not units: |
| 97 | + if decode_cf: |
| 98 | + # TODO unit = Unit(_decode_cf(self.da.attrs['units'])) |
| 99 | + raise NotImplementedError |
| 100 | + else: |
| 101 | + units = Unit(self.da.attrs['units']) |
| 102 | + |
| 103 | + quantity = array_attach_units(self.da.data, units) |
| 104 | + # TODO should we (temporarily) remove the attrs here? |
| 105 | + return xr.DataArray(dim=self.da.dims, data=quantity, |
| 106 | + coords=self.da.coords, attrs=self.da.attrs, |
| 107 | + encoding=self.da.encoding) |
| 108 | + |
| 109 | + def dequantify(self, encode_cf=True): |
| 110 | + da = xr.DataArray(dim=self.da.dims, data=self.da.pint.magnitude, |
| 111 | + coords=self.da.coords, attrs=self.da.attrs, |
| 112 | + encoding=self.da.encoding) |
| 113 | + da.attrs['units'] = self.da.pint.units |
| 114 | + return da |
| 115 | + |
| 116 | + @property |
| 117 | + def units(self): |
| 118 | + return self.da.data.units |
| 119 | + |
| 120 | + @units.setter |
| 121 | + def units(self, units): |
| 122 | + quantity = array_attach_units(self.da.data, units) |
| 123 | + self.da = xr.DataArray(dim=self.da.dims, data=quantity, |
| 124 | + coords=self.da.coords, attrs=self.da.attrs, |
| 125 | + encoding=self.da.encoding) |
| 126 | + |
| 127 | + @property |
| 128 | + def magnitude(self): |
| 129 | + return self.da.data.magnitude |
| 130 | + |
| 131 | + @magnitude.setter |
| 132 | + def magnitude(self, da): |
| 133 | + self.da = xr.DataArray(dim=self.da.dims, data=da.data, |
| 134 | + coords=self.da.coords, attrs=self.da.attrs, |
| 135 | + encoding=self.da.encoding) |
| 136 | + |
| 137 | + def to(self, units): |
| 138 | + quantity = self.da.data.to(units) |
| 139 | + return xr.DataArray(dim=self.da.dims, data=quantity, |
| 140 | + coords=self.da.coords, attrs=self.da.attrs, |
| 141 | + encoding=self.da.encoding) |
| 142 | + |
| 143 | + def to_base_units(self): |
| 144 | + quantity = self.da.data.to_base_units() |
| 145 | + return xr.DataArray(dim=self.da.dims, data=quantity, |
| 146 | + coords=self.da.coords, attrs=self.da.attrs, |
| 147 | + encoding=self.da.encoding) |
| 148 | + |
| 149 | + # TODO integrate with the uncertainties package here...? |
| 150 | + def plus_minus(self, value, error, relative=False): |
| 151 | + quantity = self.da.data.plus_minus(value, error, relative) |
| 152 | + return xr.DataArray(dim=self.da.dims, data=quantity, |
| 153 | + coords=self.da.coords, attrs=self.da.attrs, |
| 154 | + encoding=self.da.encoding) |
| 155 | + |
| 156 | + def sel(self, indexers=None, method=None, tolerance=None, drop=False, |
| 157 | + **indexers_kwargs): |
| 158 | + ... |
| 159 | + |
| 160 | + @property |
| 161 | + def loc(self): |
| 162 | + ... |
| 163 | + |
| 164 | + |
| 165 | +@xr.register_dataset_accessor("pint") |
| 166 | +class PintDatasetAccessor: |
| 167 | + def __init__(self, ds): |
| 168 | + self.ds = ds |
| 169 | + |
| 170 | + def quantify(self, unit_registry=None, decode_cf=False): |
| 171 | + quantified_vars = {name: da.pint.quantify(unit_registry=unit_registry, |
| 172 | + decode_cf=decode_cf) |
| 173 | + for name, da in self.ds.items()} |
| 174 | + return xr.Dataset(quantified_vars, attrs=self.ds.attrs, |
| 175 | + encoding=self.ds.encoding) |
| 176 | + |
| 177 | + def dequantify(self): |
| 178 | + dequantified_vars = {name: da.pint.to_base_units() |
| 179 | + for name, da in self.ds.items()} |
| 180 | + return xr.Dataset(dequantified_vars, attrs=self.ds.attrs, |
| 181 | + encoding=self.ds.encoding) |
| 182 | + |
| 183 | + def to_base_units(self): |
| 184 | + base_vars = {name: da.pint.to_base_units() |
| 185 | + for name, da in self.ds.items()} |
| 186 | + return xr.Dataset(base_vars, attrs=self.ds.attrs, encoding=self.ds.encoding) |
| 187 | + |
| 188 | + # TODO way to change every variable in ds to be expressed in a new units system? |
| 189 | + |
| 190 | + def sel(self, indexers=None, method=None, tolerance=None, drop=False, |
| 191 | + **indexers_kwargs): |
| 192 | + ... |
| 193 | + |
| 194 | + @property |
| 195 | + def loc(self): |
| 196 | + ... |
0 commit comments