1
1
"""
2
- pint_array
3
- ~~~~~~~~~~
2
+ pint_array
3
+ ~~~~~~~~~~
4
4
5
- Pint interoperability with array API standard arrays.
5
+ Pint interoperability with array API standard arrays.
6
6
"""
7
7
8
8
from __future__ import annotations
9
9
10
- from typing import Generic
11
- import types
12
10
import textwrap
11
+ import types
12
+ from typing import Generic
13
13
14
- from pint .facets .plain import MagnitudeT , PlainQuantity
15
14
from pint import Quantity
15
+ from pint .facets .plain import MagnitudeT , PlainQuantity
16
16
17
17
__version__ = "0.0.1.dev0"
18
- __all__ = ["pint_namespace " , "__version__ " ]
18
+ __all__ = ["__version__ " , "pint_namespace " ]
19
19
20
20
21
21
def pint_namespace (xp ):
22
-
23
- mod = types .ModuleType (f'pint({ xp .__name__ } )' )
22
+ mod = types .ModuleType (f"pint({ xp .__name__ } )" )
24
23
25
24
class ArrayQuantity (Generic [MagnitudeT ], PlainQuantity [MagnitudeT ]):
26
25
def __init__ (self , * args , ** kwargs ):
@@ -56,14 +55,13 @@ def size(self):
56
55
return self ._size
57
56
58
57
def __array_namespace__ (self , api_version = None ):
59
- if api_version is None or api_version == ' 2023.12' :
58
+ if api_version is None or api_version == " 2023.12" :
60
59
return mod
61
- else :
62
- raise NotImplementedError ()
63
-
60
+ raise NotImplementedError ()
61
+
64
62
def _call_super_method (self , method_name , * args , ** kwargs ):
65
63
method = getattr (self .magnitude , method_name )
66
- args = [getattr (arg , ' magnitude' , arg ) for arg in args ]
64
+ args = [getattr (arg , " magnitude" , arg ) for arg in args ]
67
65
return method (* args , ** kwargs )
68
66
69
67
## Indexing ##
@@ -86,7 +84,6 @@ def _call_super_method(self, method_name, *args, **kwargs):
86
84
# self.mask[key] = getattr(other, 'mask', False)
87
85
# return self.data.__setitem__(key, getattr(other, 'data', other))
88
86
89
-
90
87
## Visualization ##
91
88
def __repr__ (self ):
92
89
return (
@@ -108,7 +105,7 @@ def __repr__(self):
108
105
# def __rmatmul__(self, other):
109
106
# other = MArray(other)
110
107
# return mod.matmul(other, self)
111
-
108
+
112
109
## Attributes ##
113
110
114
111
@property
@@ -134,23 +131,31 @@ def to_device(self, device, /, *, stream=None):
134
131
class ArrayUnitQuantity (ArrayQuantity , Quantity ):
135
132
pass
136
133
137
-
138
134
## Methods ##
139
135
140
136
# Methods that return the result of a unary operation as an array
141
- unary_names = (
142
- ['__abs__' , '__floordiv__' , '__invert__' , '__neg__' , '__pos__' , '__ceil__' ]
143
- )
137
+ unary_names = [
138
+ "__abs__" ,
139
+ "__floordiv__" ,
140
+ "__invert__" ,
141
+ "__neg__" ,
142
+ "__pos__" ,
143
+ "__ceil__" ,
144
+ ]
144
145
for name in unary_names :
146
+
145
147
def fun (self , name = name ):
146
148
return ArrayUnitQuantity (self ._call_super_method (name ), self .units )
149
+
147
150
setattr (ArrayQuantity , name , fun )
148
151
149
152
# Methods that return the result of a unary operation as a Python scalar
150
- unary_names_py = [' __bool__' , ' __complex__' , ' __float__' , ' __index__' , ' __int__' ]
153
+ unary_names_py = [" __bool__" , " __complex__" , " __float__" , " __index__" , " __int__" ]
151
154
for name in unary_names_py :
155
+
152
156
def fun (self , name = name ):
153
157
return self ._call_super_method (name )
158
+
154
159
setattr (ArrayQuantity , name , fun )
155
160
156
161
# # Methods that return the result of an elementwise binary operation
@@ -186,20 +191,34 @@ def asarray(obj, /, *, units=None, dtype=None, device=None, copy=None):
186
191
if device is not None :
187
192
raise NotImplementedError ("`device` argument is not implemented" )
188
193
189
- magnitude = getattr (obj , ' magnitude' , obj )
194
+ magnitude = getattr (obj , " magnitude" , obj )
190
195
magnitude = xp .asarray (magnitude , dtype = dtype , device = device , copy = copy )
191
196
192
- units = getattr (obj , ' units' , None ) if units is None else units
197
+ units = getattr (obj , " units" , None ) if units is None else units
193
198
194
199
return ArrayUnitQuantity (magnitude , units )
200
+
195
201
mod .asarray = asarray
196
202
197
203
## Data Type Functions and Data Types ##
198
- dtype_fun_names = ['can_cast' , 'finfo' , 'iinfo' , 'isdtype' ]
199
- dtype_names = ['bool' , 'int8' , 'int16' , 'int32' , 'int64' , 'uint8' , 'uint16' ,
200
- 'uint32' , 'uint64' , 'float32' , 'float64' , 'complex64' , 'complex128' ]
201
- inspection_fun_names = ['__array_namespace_info__' ]
202
- version_attribute_names = ['__array_api_version__' ]
204
+ dtype_fun_names = ["can_cast" , "finfo" , "iinfo" , "isdtype" ]
205
+ dtype_names = [
206
+ "bool" ,
207
+ "int8" ,
208
+ "int16" ,
209
+ "int32" ,
210
+ "int64" ,
211
+ "uint8" ,
212
+ "uint16" ,
213
+ "uint32" ,
214
+ "uint64" ,
215
+ "float32" ,
216
+ "float64" ,
217
+ "complex64" ,
218
+ "complex128" ,
219
+ ]
220
+ inspection_fun_names = ["__array_namespace_info__" ]
221
+ version_attribute_names = ["__array_api_version__" ]
203
222
for name in (
204
223
dtype_fun_names + dtype_names + inspection_fun_names + version_attribute_names
205
224
):
@@ -211,6 +230,7 @@ def astype(x, dtype, /, *, copy=True, device=None):
211
230
x = asarray (x )
212
231
magnitude = xp .astype (x .magnitude , dtype , copy = copy , device = device )
213
232
return ArrayUnitQuantity (magnitude , x .units )
233
+
214
234
mod .astype = astype
215
235
216
236
# Handle functions that ignore units on input and output
@@ -223,12 +243,14 @@ def astype(x, dtype, /, *, copy=True, device=None):
223
243
"argmax" ,
224
244
"nonzero" ,
225
245
):
246
+
226
247
def func (x , / , * args , func_str = func_str , ** kwargs ):
227
248
x = asarray (x )
228
249
magnitude = xp .asarray (x .magnitude , copy = True )
229
250
xp_func = getattr (xp , func_str )
230
251
magnitude = xp_func (x , * args , ** kwargs )
231
252
return ArrayUnitQuantity (magnitude , None )
253
+
232
254
setattr (mod , func_str , func )
233
255
234
256
# Handle functions with output unit defined by operation
@@ -240,6 +262,7 @@ def func(x, /, *args, func_str=func_str, **kwargs):
240
262
"cumulative_sum" ,
241
263
"sum" ,
242
264
):
265
+
243
266
def func (x , / , * args , func_str = func_str , ** kwargs ):
244
267
x = asarray (x )
245
268
magnitude = xp .asarray (x .magnitude , copy = True )
@@ -248,10 +271,11 @@ def func(x, /, *args, func_str=func_str, **kwargs):
248
271
magnitude = xp_func (x , * args , ** kwargs )
249
272
units = (1 * units + 1 * units ).units
250
273
return ArrayUnitQuantity (magnitude , units )
274
+
251
275
setattr (mod , func_str , func )
252
276
253
- # output_unit="variance":
254
- # square of `x.units`,
277
+ # output_unit="variance":
278
+ # square of `x.units`,
255
279
# unless non-multiplicative, which raises `OffsetUnitCalculusError`
256
280
def var (x , / , * , axis = None , correction = 0.0 , keepdims = False ):
257
281
x = asarray (x )
@@ -260,6 +284,7 @@ def var(x, /, *, axis=None, correction=0.0, keepdims=False):
260
284
magnitude = xp .var (x , axis = axis , correction = correction , keepdims = keepdims )
261
285
units = ((1 * units + 1 * units ) ** 2 ).units
262
286
return ArrayUnitQuantity (magnitude , units )
287
+
263
288
mod .var = var
264
289
265
290
# "mul": product of all units in `all_args`
0 commit comments