Skip to content

Commit d6d5d8c

Browse files
committed
run pre-commit
1 parent 14d880e commit d6d5d8c

File tree

3 files changed

+67
-42
lines changed

3 files changed

+67
-42
lines changed

.github/dependabot.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ updates:
88
groups:
99
actions:
1010
patterns:
11-
- "*"
11+
- "*"

src/pint_array/__init__.py

Lines changed: 55 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,25 @@
11
"""
2-
pint_array
3-
~~~~~~~~~~
2+
pint_array
3+
~~~~~~~~~~
44
5-
Pint interoperability with array API standard arrays.
5+
Pint interoperability with array API standard arrays.
66
"""
77

88
from __future__ import annotations
99

10-
from typing import Generic
11-
import types
1210
import textwrap
11+
import types
12+
from typing import Generic
1313

14-
from pint.facets.plain import MagnitudeT, PlainQuantity
1514
from pint import Quantity
15+
from pint.facets.plain import MagnitudeT, PlainQuantity
1616

1717
__version__ = "0.0.1.dev0"
18-
__all__ = ["pint_namespace", "__version__"]
18+
__all__ = ["__version__", "pint_namespace"]
1919

2020

2121
def pint_namespace(xp):
22-
23-
mod = types.ModuleType(f'pint({xp.__name__})')
22+
mod = types.ModuleType(f"pint({xp.__name__})")
2423

2524
class ArrayQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
2625
def __init__(self, *args, **kwargs):
@@ -56,14 +55,13 @@ def size(self):
5655
return self._size
5756

5857
def __array_namespace__(self, api_version=None):
59-
if api_version is None or api_version == '2023.12':
58+
if api_version is None or api_version == "2023.12":
6059
return mod
61-
else:
62-
raise NotImplementedError()
63-
60+
raise NotImplementedError()
61+
6462
def _call_super_method(self, method_name, *args, **kwargs):
6563
method = getattr(self.magnitude, method_name)
66-
args = [getattr(arg, 'magnitude', arg) for arg in args]
64+
args = [getattr(arg, "magnitude", arg) for arg in args]
6765
return method(*args, **kwargs)
6866

6967
## Indexing ##
@@ -86,7 +84,6 @@ def _call_super_method(self, method_name, *args, **kwargs):
8684
# self.mask[key] = getattr(other, 'mask', False)
8785
# return self.data.__setitem__(key, getattr(other, 'data', other))
8886

89-
9087
## Visualization ##
9188
def __repr__(self):
9289
return (
@@ -108,7 +105,7 @@ def __repr__(self):
108105
# def __rmatmul__(self, other):
109106
# other = MArray(other)
110107
# return mod.matmul(other, self)
111-
108+
112109
## Attributes ##
113110

114111
@property
@@ -134,23 +131,31 @@ def to_device(self, device, /, *, stream=None):
134131
class ArrayUnitQuantity(ArrayQuantity, Quantity):
135132
pass
136133

137-
138134
## Methods ##
139135

140136
# Methods that return the result of a unary operation as an array
141-
unary_names = (
142-
['__abs__', '__floordiv__', '__invert__', '__neg__', '__pos__', '__ceil__']
143-
)
137+
unary_names = [
138+
"__abs__",
139+
"__floordiv__",
140+
"__invert__",
141+
"__neg__",
142+
"__pos__",
143+
"__ceil__",
144+
]
144145
for name in unary_names:
146+
145147
def fun(self, name=name):
146148
return ArrayUnitQuantity(self._call_super_method(name), self.units)
149+
147150
setattr(ArrayQuantity, name, fun)
148151

149152
# Methods that return the result of a unary operation as a Python scalar
150-
unary_names_py = ['__bool__', '__complex__', '__float__', '__index__', '__int__']
153+
unary_names_py = ["__bool__", "__complex__", "__float__", "__index__", "__int__"]
151154
for name in unary_names_py:
155+
152156
def fun(self, name=name):
153157
return self._call_super_method(name)
158+
154159
setattr(ArrayQuantity, name, fun)
155160

156161
# # Methods that return the result of an elementwise binary operation
@@ -186,20 +191,34 @@ def asarray(obj, /, *, units=None, dtype=None, device=None, copy=None):
186191
if device is not None:
187192
raise NotImplementedError("`device` argument is not implemented")
188193

189-
magnitude = getattr(obj, 'magnitude', obj)
194+
magnitude = getattr(obj, "magnitude", obj)
190195
magnitude = xp.asarray(magnitude, dtype=dtype, device=device, copy=copy)
191196

192-
units = getattr(obj, 'units', None) if units is None else units
197+
units = getattr(obj, "units", None) if units is None else units
193198

194199
return ArrayUnitQuantity(magnitude, units)
200+
195201
mod.asarray = asarray
196202

197203
## Data Type Functions and Data Types ##
198-
dtype_fun_names = ['can_cast', 'finfo', 'iinfo', 'isdtype']
199-
dtype_names = ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16',
200-
'uint32', 'uint64', 'float32', 'float64', 'complex64', 'complex128']
201-
inspection_fun_names = ['__array_namespace_info__']
202-
version_attribute_names = ['__array_api_version__']
204+
dtype_fun_names = ["can_cast", "finfo", "iinfo", "isdtype"]
205+
dtype_names = [
206+
"bool",
207+
"int8",
208+
"int16",
209+
"int32",
210+
"int64",
211+
"uint8",
212+
"uint16",
213+
"uint32",
214+
"uint64",
215+
"float32",
216+
"float64",
217+
"complex64",
218+
"complex128",
219+
]
220+
inspection_fun_names = ["__array_namespace_info__"]
221+
version_attribute_names = ["__array_api_version__"]
203222
for name in (
204223
dtype_fun_names + dtype_names + inspection_fun_names + version_attribute_names
205224
):
@@ -211,6 +230,7 @@ def astype(x, dtype, /, *, copy=True, device=None):
211230
x = asarray(x)
212231
magnitude = xp.astype(x.magnitude, dtype, copy=copy, device=device)
213232
return ArrayUnitQuantity(magnitude, x.units)
233+
214234
mod.astype = astype
215235

216236
# Handle functions that ignore units on input and output
@@ -223,12 +243,14 @@ def astype(x, dtype, /, *, copy=True, device=None):
223243
"argmax",
224244
"nonzero",
225245
):
246+
226247
def func(x, /, *args, func_str=func_str, **kwargs):
227248
x = asarray(x)
228249
magnitude = xp.asarray(x.magnitude, copy=True)
229250
xp_func = getattr(xp, func_str)
230251
magnitude = xp_func(x, *args, **kwargs)
231252
return ArrayUnitQuantity(magnitude, None)
253+
232254
setattr(mod, func_str, func)
233255

234256
# Handle functions with output unit defined by operation
@@ -240,6 +262,7 @@ def func(x, /, *args, func_str=func_str, **kwargs):
240262
"cumulative_sum",
241263
"sum",
242264
):
265+
243266
def func(x, /, *args, func_str=func_str, **kwargs):
244267
x = asarray(x)
245268
magnitude = xp.asarray(x.magnitude, copy=True)
@@ -248,10 +271,11 @@ def func(x, /, *args, func_str=func_str, **kwargs):
248271
magnitude = xp_func(x, *args, **kwargs)
249272
units = (1 * units + 1 * units).units
250273
return ArrayUnitQuantity(magnitude, units)
274+
251275
setattr(mod, func_str, func)
252276

253-
# output_unit="variance":
254-
# square of `x.units`,
277+
# output_unit="variance":
278+
# square of `x.units`,
255279
# unless non-multiplicative, which raises `OffsetUnitCalculusError`
256280
def var(x, /, *, axis=None, correction=0.0, keepdims=False):
257281
x = asarray(x)
@@ -260,6 +284,7 @@ def var(x, /, *, axis=None, correction=0.0, keepdims=False):
260284
magnitude = xp.var(x, axis=axis, correction=correction, keepdims=keepdims)
261285
units = ((1 * units + 1 * units) ** 2).units
262286
return ArrayUnitQuantity(magnitude, units)
287+
263288
mod.var = var
264289

265290
# "mul": product of all units in `all_args`

src/pint_array/funcs.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
2-
pint_array.funcs
3-
~~~~~~~~~~~~~~~~
2+
pint_array.funcs
3+
~~~~~~~~~~~~~~~~
44
"""
55

66
from __future__ import annotations
@@ -64,7 +64,7 @@ def _get_first_input_units(args, kwargs=None):
6464
for arg in chain(args, kwargs.values()):
6565
if _is_quantity(arg):
6666
return arg.units
67-
elif _is_sequence_with_quantity_elements(arg):
67+
if _is_sequence_with_quantity_elements(arg):
6868
return next(arg_i.units for arg_i in arg if _is_quantity(arg_i))
6969
raise TypeError("Expected at least one Quantity; found none")
7070

@@ -80,15 +80,14 @@ def convert_arg(arg, pre_calc_units):
8080
if pre_calc_units is not None:
8181
if _is_quantity(arg):
8282
return arg.m_as(pre_calc_units)
83-
elif _is_sequence_with_quantity_elements(arg):
83+
if _is_sequence_with_quantity_elements(arg):
8484
return [convert_arg(item, pre_calc_units) for item in arg]
85-
elif arg is not None:
85+
if arg is not None:
8686
if pre_calc_units.dimensionless:
8787
return pre_calc_units._REGISTRY.Quantity(arg).m_as(pre_calc_units)
88-
elif not _is_quantity(arg) and zero_or_nan(arg, True):
88+
if not _is_quantity(arg) and zero_or_nan(arg, True):
8989
return arg
90-
else:
91-
raise DimensionalityError("dimensionless", pre_calc_units)
90+
raise DimensionalityError("dimensionless", pre_calc_units)
9291
elif _is_quantity(arg):
9392
return arg.m
9493
elif _is_sequence_with_quantity_elements(arg):
@@ -285,7 +284,8 @@ def implementation(*args, **kwargs):
285284
):
286285
# the sequence may contain different units, so fall back to element-wise
287286
return xp.asarray(
288-
[func(*func_args) for func_args in zip(*args)], dtype=object
287+
[func(*func_args) for func_args in zip(*args, strict=False)],
288+
dtype=object,
289289
)
290290

291291
first_input_units = _get_first_input_units(args, kwargs)
@@ -313,7 +313,7 @@ def implementation(*args, **kwargs):
313313
if output_unit is None:
314314
# Short circuit and return magnitude alone
315315
return result_magnitude
316-
elif output_unit == "match_input":
316+
if output_unit == "match_input":
317317
result_unit = first_input_units
318318
elif output_unit in (
319319
"sum",
@@ -734,7 +734,7 @@ def implementation(*args, **kwargs):
734734
# implement_prod_func(name)
735735

736736

737-
# # Handle mutliplicative functions separately to deal with non-multiplicative units
737+
# # Handle multiplicative functions separately to deal with non-multiplicative units
738738
# def _base_unit_if_needed(a):
739739
# if a._is_multiplicative:
740740
# return a

0 commit comments

Comments
 (0)