Skip to content

Commit 7f9f932

Browse files
committed
elementwise
1 parent af8a3e2 commit 7f9f932

File tree

1 file changed

+154
-0
lines changed

1 file changed

+154
-0
lines changed

src/pint_array/__init__.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
import types
1212
from typing import Generic
1313

14+
from array_api_compat import is_array_api_obj
1415
from pint import Quantity
1516
from pint.facets.plain import MagnitudeT, PlainQuantity
17+
from pint.util import iterable, sized
1618

1719
__version__ = "0.0.1.dev0"
1820
__all__ = ["__version__", "pint_namespace"]
@@ -211,6 +213,24 @@ def asarray(obj, /, *, units=None, dtype=None, device=None, copy=None):
211213

212214
mod.asarray = asarray
213215

216+
creation_functions = [
217+
"arange",
218+
"empty",
219+
"eye",
220+
"from_dlpack",
221+
"full",
222+
"linspace",
223+
"ones",
224+
"zeros",
225+
]
226+
for func_str in creation_functions:
227+
228+
def fun(*args, func_str=func_str, units=None, **kwargs):
229+
magnitude = getattr(xp, func_str)(*args, **kwargs)
230+
return ArrayUnitQuantity(magnitude, units)
231+
232+
setattr(mod, func_str, fun)
233+
214234
## Data Type Functions and Data Types ##
215235
dtype_fun_names = ["can_cast", "finfo", "iinfo", "isdtype"]
216236
dtype_names = [
@@ -280,6 +300,140 @@ def func(x, /, *args, func_str=func_str, **kwargs):
280300

281301
setattr(mod, func_str, func)
282302

303+
elementwise_one_array = [
304+
"abs",
305+
"acos",
306+
"acosh",
307+
"asin",
308+
"asinh",
309+
"atan",
310+
"atanh",
311+
"bitwise_invert",
312+
"ceil",
313+
"conj",
314+
"cos",
315+
"cosh",
316+
"exp",
317+
"expm1",
318+
"floor",
319+
"imag",
320+
"isfinite",
321+
"isinf",
322+
"isnan",
323+
"log",
324+
"log1p",
325+
"log2",
326+
"log10",
327+
"logical_not",
328+
"negative",
329+
"positive",
330+
"real",
331+
"round",
332+
"sign",
333+
"signbit",
334+
"sin",
335+
"sinh",
336+
"square",
337+
"sqrt",
338+
"tan",
339+
"tanh",
340+
"trunc",
341+
]
342+
for func_str in elementwise_one_array:
343+
344+
def fun(x, /, *args, func_str=func_str, **kwargs):
345+
x = asarray(x)
346+
magnitude = xp.asarray(x.magnitude, copy=True)
347+
magnitude = getattr(xp, func_str)(x, *args, **kwargs)
348+
return ArrayUnitQuantity(magnitude, x.units)
349+
350+
setattr(mod, func_str, fun)
351+
352+
def _is_quantity(obj):
353+
"""Test for _units and _magnitude attrs.
354+
355+
This is done in place of isinstance(Quantity, arg),
356+
which would cause a circular import.
357+
358+
Parameters
359+
----------
360+
obj : Object
361+
362+
Returns
363+
-------
364+
bool
365+
"""
366+
return hasattr(obj, "_units") and hasattr(obj, "_magnitude")
367+
368+
def _is_sequence_with_quantity_elements(obj):
369+
"""Test for sequences of quantities.
370+
371+
Parameters
372+
----------
373+
obj : object
374+
375+
Returns
376+
-------
377+
True if obj is a sequence and at least one element is a Quantity;
378+
False otherwise
379+
"""
380+
if is_array_api_obj(obj) and not obj.dtype.hasobject:
381+
# If obj is an array, avoid looping on all elements
382+
# if dtype does not have objects
383+
return False
384+
return (
385+
iterable(obj)
386+
and sized(obj)
387+
and not isinstance(obj, str)
388+
and any(_is_quantity(item) for item in obj)
389+
)
390+
391+
elementwise_two_arrays = [
392+
"add",
393+
"atan2",
394+
"bitwise_and",
395+
"bitwise_left_shift",
396+
"bitwise_or",
397+
"bitwise_right_shift",
398+
"bitwise_xor",
399+
"copysign",
400+
"divide",
401+
"equal",
402+
"floor_divide",
403+
"greater",
404+
"greater_equal",
405+
"hypot",
406+
"less",
407+
"less_equal",
408+
"logaddexp",
409+
"logical_and",
410+
"logical_or",
411+
"logical_xor",
412+
"maximum",
413+
"minimum",
414+
"multiply",
415+
"not_equal",
416+
"pow",
417+
"remainder",
418+
"subtract",
419+
]
420+
for func_str in elementwise_two_arrays:
421+
422+
def fun(x1, x2, /, *args, func_str=func_str, **kwargs):
423+
x1 = asarray(x1)
424+
x2 = asarray(x2)
425+
426+
units = x1.units
427+
428+
x1_magnitude = xp.asarray(x1.magnitude, copy=True)
429+
x2_magnitude = x2.m_as(units)
430+
431+
xp_func = getattr(xp, func_str)
432+
magnitude = xp_func(x1_magnitude, x2_magnitude, *args, **kwargs)
433+
return ArrayUnitQuantity(magnitude, units)
434+
435+
setattr(mod, func_str, fun)
436+
283437
# Handle functions with output unit defined by operation
284438

285439
# output_unit="sum":

0 commit comments

Comments
 (0)