Skip to content

Commit 2e5d5df

Browse files
committed
refactored to allow quantifying dataset by variable
1 parent f8a4374 commit 2e5d5df

File tree

1 file changed

+49
-26
lines changed

1 file changed

+49
-26
lines changed

pintxarray/accessors.py

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
# TODO docstrings
2626
# TODO type hints
27-
# TODO f-strings
2827

2928

3029
def _array_attach_units(data, unit, convert_from=None):
@@ -68,6 +67,35 @@ def _array_attach_units(data, unit, convert_from=None):
6867

6968
return quantity
7069

70+
def _get_registry(unit_registry, registry_kwargs):
71+
if unit_registry is None:
72+
if registry_kwargs is None:
73+
registry_kwargs = {}
74+
registry_kwargs.update(force_ndarray=True)
75+
# TODO should this registry object then be stored somewhere global?
76+
unit_registry = pint.UnitRegistry(**registry_kwargs)
77+
return unit_registry
78+
79+
def _decide_units(units, registry, attrs):
80+
if units is None:
81+
# TODO option to read and decode units according to CF conventions (see MetPy)?
82+
attr_units = attrs['units']
83+
units = registry.parse_expression(attr_units)
84+
elif isinstance(units, Unit):
85+
# TODO do we have to check what happens if someone passes a Unit instance
86+
# without creating a unit registry?
87+
# TODO and what happens if they pass in a Unit from a different registry
88+
pass
89+
else:
90+
units = registry.Unit(units)
91+
return units
92+
93+
def _quantify_variable(var, units):
    """Build a new variable whose data carries ``units``.

    Parameters
    ----------
    var
        Variable to quantify; its ``dims``, ``data`` and ``attrs`` are read.
    units
        Units handed to ``_array_attach_units`` together with ``var.data``.

    Returns
    -------
    Variable
        New variable with the same dims and attrs as ``var``, the data
        returned by ``_array_attach_units``, and ``attrs['units']`` set to
        the string form of that data's units.
    """
    new_data = _array_attach_units(var.data, units, convert_from=None)
    new_var = Variable(dims=var.dims, data=new_data,
                       attrs=var.attrs)
    # Read the units from the freshly quantified data: ``var.data`` is the
    # original input, which has no ``.units`` when it is a plain array.
    # NOTE(review): assumes _array_attach_units returns an object exposing
    # ``.units`` (a pint quantity) — confirm against its full definition.
    new_var.attrs['units'] = str(new_data.units)
    return new_var
7199

72100
def _dequantify_variable(var):
73101
new_var = Variable(dims=var.dims, data=var.data.magnitude,
@@ -128,24 +156,9 @@ def quantify(self, units=None, unit_registry=None, registry_kwargs=None):
128156
raise ValueError(f"Cannot attach unit {units} to quantity: data "
129157
f"already has units {self.da.data.units}")
130158

131-
if unit_registry is None:
132-
if registry_kwargs is None:
133-
registry_kwargs = {}
134-
registry_kwargs.update(force_ndarray=True)
135-
# TODO should this registry object then be stored somewhere global?
136-
unit_registry = pint.UnitRegistry(**registry_kwargs)
137-
138-
if units is None:
139-
# TODO option to read and decode units according to CF conventions (see MetPy)?
140-
attr_units = self.da.attrs['units']
141-
units = unit_registry.parse_expression(attr_units)
142-
elif isinstance(units, Unit):
143-
# TODO do we have to check what happens if someone passes a Unit instance
144-
# without creating a unit registry?
145-
# TODO and what happens if they pass in a Unit from a different registry
146-
pass
147-
else:
148-
units = unit_registry.Unit(units)
159+
registry = _get_registry(unit_registry, registry_kwargs)
160+
161+
units = _decide_units(units, registry, self.da.attrs)
149162

150163
quantity = _array_attach_units(self.da.data, units, convert_from=None)
151164

@@ -240,12 +253,20 @@ class PintDatasetAccessor:
240253
def __init__(self, ds):
241254
self.ds = ds
242255

243-
def quantify(self, unit_registry=None, decode_cf=False):
244-
quantified_vars = {name: da.pint.quantify(unit_registry=unit_registry,
245-
decode_cf=decode_cf)
246-
for name, da in self.ds.items()}
247-
return Dataset(quantified_vars, attrs=self.ds.attrs,
248-
encoding=self.ds.encoding)
256+
def quantify(self, unit_registry=None, registry_kwargs=None):
    """Attach units to every data variable of the wrapped dataset.

    The units of each variable are read from its own ``attrs['units']``.

    Parameters
    ----------
    unit_registry : optional
        Registry used to parse the units. When None, one is created by
        ``_get_registry`` from ``registry_kwargs``.
    registry_kwargs : dict, optional
        Keyword arguments for the registry created when ``unit_registry``
        is None.

    Returns
    -------
    Dataset
        New dataset with quantified data variables and the coordinates and
        attributes of the original dataset.
    """
    registry = _get_registry(unit_registry, registry_kwargs)

    # ``data_vars`` is a mapping of name -> DataArray (a property, not a
    # method), so iterate its items to get at each variable's attrs.
    # Resolve all units first so that a variable with unparseable or
    # missing units fails before any data is wrapped.
    var_units = {name: _decide_units(None, registry, da.attrs)
                 for name, da in self.ds.data_vars.items()}

    new_vars = {name: _quantify_variable(da.variable, var_units[name])
                for name, da in self.ds.data_vars.items()}

    # TODO should also quantify coordinates (once explicit indexes ready)
    # TODO should we (temporarily) remove the attrs here so that they don't become inconsistent?
    # The accessor only stores ``self.ds``; coordinates live on the dataset.
    return Dataset(data_vars=new_vars, coords=self.ds.coords,
                   attrs=self.ds.attrs)
249270

250271
def dequantify(self):
251272
dequantified_vars = {name: da.pint.to_base_units()
@@ -258,7 +279,9 @@ def to_base_units(self):
258279
for name, da in self.ds.items()}
259280
return Dataset(base_vars, attrs=self.ds.attrs, encoding=self.ds.encoding)
260281

261-
# TODO way to change every variable in ds to be expressed in a new units system?
282+
# TODO unsure if the upstream capability exists in pint for this yet.
283+
def to_system(self, system):
    """Placeholder for converting every variable to a given unit system.

    Not implemented: always raises ``NotImplementedError``. ``system`` is
    accepted but unused.
    """
    raise NotImplementedError()
262285

263286
def sel(self, indexers=None, method=None, tolerance=None, drop=False,
264287
**indexers_kwargs):

0 commit comments

Comments
 (0)