Skip to content

Commit 3ef40b9

Browse files
committed
dataset.quantify
1 parent 2e5d5df commit 3ef40b9

File tree

2 files changed

+130
-33
lines changed

2 files changed

+130
-33
lines changed

pintxarray/accessors.py

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,11 @@ def _quantify_variable(var, units):
9494
new_data = _array_attach_units(var.data, units, convert_from=None)
9595
new_var = Variable(dims=var.dims, data=new_data,
9696
attrs=var.attrs)
97-
new_var.attrs['units'] = str(var.data.units)
9897
return new_var
9998

10099
def _dequantify_variable(var):
101100
new_var = Variable(dims=var.dims, data=var.data.magnitude,
102-
attrs=var.attrs)
101+
attrs=var.attrs)
103102
new_var.attrs['units'] = str(var.data.units)
104103
return new_var
105104

@@ -119,9 +118,9 @@ def quantify(self, units=None, unit_registry=None, registry_kwargs=None):
119118
"""
120119
Attaches units to the DataArray.
121120
122-
Units can be specified as a pint.Unit or as a string, which will will
123-
be parsed by the given unit registry. If no units are specified then
124-
the units will be parsed from the `'units'` entry of the DataArray's
121+
Units can be specified as a pint.Unit or as a string, which will be
122+
parsed by the given unit registry. If no units are specified then the
123+
units will be parsed from the `'units'` entry of the DataArray's
125124
`.attrs`. Will raise a ValueError if the DataArray already contains a
126125
unit-aware array.
127126
@@ -193,10 +192,6 @@ def dequantify(self):
193192
def magnitude(self):
194193
return self.da.data.magnitude
195194

196-
@magnitude.setter
197-
def magnitude(self, da):
198-
self.da = DataArray(dim=self.da.dims, data=da.data,
199-
coords=self.da.coords, attrs=self.da.attrs)
200195
@property
201196
def units(self):
202197
return self.da.data.units
@@ -250,34 +245,76 @@ def loc(self):
250245

251246
@register_dataset_accessor("pint")
252247
class PintDatasetAccessor:
248+
"""
249+
Access methods for DataArrays with units using Pint.
250+
251+
Methods and attributes can be accessed through the `.pint` attribute.
252+
"""
253253
def __init__(self, ds):
254254
self.ds = ds
255255

256-
def quantify(self, unit_registry=None, registry_kwargs=None):
256+
def quantify(self, units=None, unit_registry=None, registry_kwargs=None):
257+
"""
258+
Attaches units to each variable in the Dataset.
259+
260+
Units can be specified as a pint.Unit or as a string, which will
261+
be parsed by the given unit registry. If no units are specified then
262+
the units will be parsed from the `'units'` entry of the DataArray's
263+
`.attrs`. Will raise a ValueError if any of the DataArrays already
264+
contain a unit-aware array.
265+
266+
Parameters
267+
----------
268+
units : mapping from variable names to pint.Unit or str, optional
269+
Physical units to use for particular DataArrays in this Dataset. If
270+
not provided, will try to read them from
271+
`Dataset[var].attrs['units']` using pint's parser.
272+
unit_registry : `pint.UnitRegistry`, optional
273+
Unit registry to be used for the units attached to each DataArray
274+
in this Dataset. If not given then a default registry will be
275+
created.
276+
registry_kwargs : dict, optional
277+
Keyword arguments to be passed to `pint.UnitRegistry`.
278+
279+
Returns
280+
-------
281+
quantified - Dataset whose variables will now contain Quantity
282+
arrays with units.
283+
"""
284+
285+
for var in self.ds.data_vars:
286+
if isinstance(self.ds[var].data, Quantity):
287+
raise ValueError(f"Cannot attach unit to quantity: data "
288+
f"variable {var} already has units "
289+
f"{self.ds[var].data.units}")
257290

258291
registry = _get_registry(unit_registry, registry_kwargs)
259292

260-
var_units = [_decide_units(None, registry, var.attrs)
261-
for var in self.ds.data_vars]
293+
if units is None:
294+
units = {name: None for name in self.ds}
262295

263-
new_vars = {name: _quantify_variable(var, units)
264-
for name, var, units in zip(self.ds.data_vars(), var_units)}
296+
units = {name: _decide_units(units.get(name, None), registry, var.attrs)
297+
for name, var in self.ds.data_vars.items()}
298+
299+
quantified_vars = {name: _quantify_variable(var, units[name])
300+
for name, var in self.ds.data_vars.items()}
265301

266302
# TODO should also quantify coordinates (once explicit indexes ready)
267303
# TODO should we (temporarily) remove the attrs here so that they don't become inconsistent?
268-
return Dataset(data_vars=new_vars, coords=self.coords,
304+
return Dataset(data_vars=quantified_vars, coords=self.ds.coords,
269305
attrs=self.ds.attrs)
270306

271307
def dequantify(self):
272308
dequantified_vars = {name: da.pint.to_base_units()
273309
for name, da in self.ds.items()}
274-
return Dataset(dequantified_vars, attrs=self.ds.attrs,
275-
encoding=self.ds.encoding)
310+
return Dataset(dequantified_vars, coords=self.ds.coords,
311+
attrs=self.ds.attrs)
276312

277313
def to_base_units(self):
278314
base_vars = {name: da.pint.to_base_units()
279315
for name, da in self.ds.items()}
280-
return Dataset(base_vars, attrs=self.ds.attrs, encoding=self.ds.encoding)
316+
return Dataset(base_vars, coords=self.ds.coords,
317+
attrs=self.ds.attrs)
281318

282319
# TODO unsure if the upstream capability exists in pint for this yet.
283320
def to_system(self, system):

pintxarray/tests/test_accessors.py

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,27 +37,27 @@ def example_quantity_da():
3737

3838

3939
class TestQuantifyDataArray:
40+
def test_attach_units_from_str(self, example_unitless_da):
41+
orig = example_unitless_da
42+
result = orig.pint.quantify('m')
43+
assert_array_equal(result.data.magnitude, orig.data)
44+
# TODO better comparisons for when you can't access the unit_registry?
45+
assert str(result.data.units) == 'meter'
46+
4047
def test_attach_units_given_registry(self, example_unitless_da):
4148
orig = example_unitless_da
4249
ureg = UnitRegistry(force_ndarray=True)
4350
result = orig.pint.quantify('m', unit_registry=ureg)
4451
assert_array_equal(result.data.magnitude, orig.data)
4552
assert result.data.units == ureg.Unit('m')
4653

47-
def test_attach_units(self, example_unitless_da):
48-
orig = example_unitless_da
49-
result = orig.pint.quantify('m')
50-
assert_array_equal(result.data.magnitude, orig.data)
51-
# TODO better comparisons for when you can't access the unit_registry?
52-
assert str(result.data.units) == 'meter'
53-
5454
def test_attach_units_from_attrs(self, example_unitless_da):
5555
orig = example_unitless_da
5656
result = orig.pint.quantify()
5757
assert_array_equal(result.data.magnitude, orig.data)
5858
assert str(result.data.units) == 'meter'
5959

60-
def test_attach_unit_class(self, example_unitless_da):
60+
def test_attach_units_given_unit_objs(self, example_unitless_da):
6161
orig = example_unitless_da
6262
ureg = UnitRegistry(force_ndarray=True)
6363
result = orig.pint.quantify(ureg.Unit('m'), unit_registry=ureg)
@@ -103,6 +103,29 @@ def test_roundtrip_data(self, example_unitless_da):
103103
assert_equal(result, orig)
104104

105105

106+
@pytest.fixture()
107+
def example_unitless_ds():
108+
users = np.linspace(0, 10, 20)
109+
funds = np.logspace(0, 10, 20)
110+
t = np.arange(20)
111+
ds = xr.Dataset(data_vars={'users': (['t'], users),
112+
'funds': (['t'], funds)},
113+
coords={"t": t})
114+
ds['users'].attrs['units'] = ''
115+
ds['funds'].attrs['units'] = 'pound'
116+
return ds
117+
118+
@pytest.fixture()
119+
def example_quantity_ds():
120+
users = np.linspace(0, 10, 20) * unit_registry.dimensionless
121+
funds = np.logspace(0, 10, 20) * unit_registry.pound
122+
t = np.arange(20)
123+
ds = xr.Dataset(data_vars={'users': (['t'], users),
124+
'funds': (['t'], funds)},
125+
coords={"t": t})
126+
return ds
127+
128+
106129
@pytest.mark.skip(reason="Not yet implemented")
107130
class TestPropertiesDataArray:
108131
def test_units(self):
@@ -121,16 +144,53 @@ def test_units(self):
121144
...
122145

123146

124-
@pytest.mark.skip(reason="Not yet implemented")
125147
class TestQuantifyDataSet:
126-
def test_attach_units(self):
127-
...
148+
def test_attach_units_from_str(self, example_unitless_ds):
149+
orig = example_unitless_ds
150+
result = orig.pint.quantify()
151+
assert_array_equal(result['users'].data.magnitude,
152+
orig['users'].data)
153+
assert str(result['users'].data.units) == 'dimensionless'
154+
155+
def test_attach_units_given_registry(self, example_unitless_ds):
156+
orig = example_unitless_ds
157+
orig['users'].attrs.clear()
158+
result = orig.pint.quantify({'users': 'dimensionless'},
159+
unit_registry=unit_registry)
160+
assert_array_equal(result['users'].data.magnitude,
161+
orig['users'].data)
162+
assert str(result['users'].data.units) == 'dimensionless'
163+
164+
def test_attach_units_from_attrs(self, example_unitless_ds):
165+
orig = example_unitless_ds
166+
orig['users'].attrs.clear()
167+
result = orig.pint.quantify({'users': 'dimensionless'})
168+
assert_array_equal(result['users'].data.magnitude,
169+
orig['users'].data)
170+
assert str(result['users'].data.units) == 'dimensionless'
171+
172+
def test_attach_units_given_unit_objs(self, example_unitless_ds):
173+
orig = example_unitless_ds
174+
orig['users'].attrs.clear()
175+
dimensionless = unit_registry.Unit('dimensionless')
176+
result = orig.pint.quantify({'users': dimensionless})
177+
assert_array_equal(result['users'].data.magnitude,
178+
orig['users'].data)
179+
assert str(result['users'].data.units) == 'dimensionless'
180+
181+
def test_error_when_already_units(self, example_quantity_ds):
182+
with raises_regex(ValueError, "already has units"):
183+
example_quantity_ds.pint.quantify()
128184

129-
def test_attach_str_units(self):
130-
...
185+
def test_error_on_nonsense_units(self, example_unitless_ds):
186+
ds = example_unitless_ds
187+
with pytest.raises(UndefinedUnitError):
188+
ds.pint.quantify(units={'users': 'aecjhbav'})
131189

132-
def test_attach_units_from_registry(self):
133-
...
190+
191+
@pytest.mark.skip(reason="Not yet implemented")
192+
class TestDequantifyDataSet:
193+
...
134194

135195

136196
@pytest.mark.skip(reason="Not yet implemented")

0 commit comments

Comments
 (0)