Skip to content

Commit 1a1427c

Browse files
committed
add linear algebra funcs
1 parent babbc9a commit 1a1427c

File tree

1 file changed

+116
-0
lines changed

1 file changed

+116
-0
lines changed

src/pint_array/__init__.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,100 @@ def func(x, /, *args, func_str=func_str, **kwargs):
298298

299299
setattr(mod, func_str, func)
300300

301+
# strip_unit_input_output_ufuncs = ["isnan", "isinf", "isfinite", "signbit", "sign"]
302+
# matching_input_bare_output_ufuncs = [
303+
# "equal",
304+
# "greater",
305+
# "greater_equal",
306+
# "less",
307+
# "less_equal",
308+
# "not_equal",
309+
# ]
310+
# matching_input_set_units_output_ufuncs = {"arctan2": "radian"}
311+
# set_units_ufuncs = {
312+
# "cumprod": ("", ""),
313+
# "arccos": ("", "radian"),
314+
# "arcsin": ("", "radian"),
315+
# "arctan": ("", "radian"),
316+
# "arccosh": ("", "radian"),
317+
# "arcsinh": ("", "radian"),
318+
# "arctanh": ("", "radian"),
319+
# "exp": ("", ""),
320+
# "expm1": ("", ""),
321+
# "exp2": ("", ""),
322+
# "log": ("", ""),
323+
# "log10": ("", ""),
324+
# "log1p": ("", ""),
325+
# "log2": ("", ""),
326+
# "sin": ("radian", ""),
327+
# "cos": ("radian", ""),
328+
# "tan": ("radian", ""),
329+
# "sinh": ("radian", ""),
330+
# "cosh": ("radian", ""),
331+
# "tanh": ("radian", ""),
332+
# "radians": ("degree", "radian"),
333+
# "degrees": ("radian", "degree"),
334+
# "deg2rad": ("degree", "radian"),
335+
# "rad2deg": ("radian", "degree"),
336+
# "logaddexp": ("", ""),
337+
# "logaddexp2": ("", ""),
338+
# }
339+
# # TODO (#905 follow-up):
340+
# # while this matches previous behavior, some of these have optional arguments
341+
# that
342+
# # should not be Quantities. This should be fixed, and tests using these optional
343+
# # arguments should be added.
344+
# matching_input_copy_units_output_ufuncs = [
345+
# "compress",
346+
# "conj",
347+
# "conjugate",
348+
# "copy",
349+
# "diagonal",
350+
# "max",
351+
# "mean",
352+
# "min",
353+
# "ptp",
354+
# "ravel",
355+
# "repeat",
356+
# "reshape",
357+
# "round",
358+
# "squeeze",
359+
# "swapaxes",
360+
# "take",
361+
# "trace",
362+
# "transpose",
363+
# "roll",
364+
# "ceil",
365+
# "floor",
366+
# "hypot",
367+
# "rint",
368+
# "copysign",
369+
# "nextafter",
370+
# "trunc",
371+
# "absolute",
372+
# "positive",
373+
# "negative",
374+
# "maximum",
375+
# "minimum",
376+
# "fabs",
377+
# ]
378+
# copy_units_output_ufuncs = ["ldexp", "fmod", "mod", "remainder"]
379+
# op_units_output_ufuncs = {
380+
# "var": "square",
381+
# "multiply": "mul",
382+
# "true_divide": "div",
383+
# "divide": "div",
384+
# "floor_divide": "div",
385+
# "sqrt": "sqrt",
386+
# "cbrt": "cbrt",
387+
# "square": "square",
388+
# "reciprocal": "reciprocal",
389+
# "std": "sum",
390+
# "sum": "sum",
391+
# "cumsum": "sum",
392+
# "matmul": "mul",
393+
# }
394+
301395
elementwise_one_array = [
302396
"abs",
303397
"acos",
@@ -393,6 +487,28 @@ def fun(x1, x2, /, *args, func_str=func_str, **kwargs):
393487

394488
setattr(mod, func_str, fun)
395489

490+
def get_linalg_fun(func_str):
491+
def linalg_fun(x1, x2, /, **kwargs):
492+
x1 = asarray(x1)
493+
x2 = asarray(x2)
494+
magnitude1 = xp.asarray(x1.magnitude, copy=True)
495+
magnitude2 = xp.asarray(x2.magnitude, copy=True)
496+
497+
xp_func = getattr(xp, func_str)
498+
magnitude = xp_func(magnitude1, magnitude2, **kwargs)
499+
return ArrayUnitQuantity(magnitude, x1.units * x2.units)
500+
501+
return linalg_fun
502+
503+
linalg_names = ["matmul", "tensordot", "vecdot"]
504+
for name in linalg_names:
505+
setattr(mod, name, get_linalg_fun(name))
506+
507+
def matrix_transpose(x):
508+
return x.mT
509+
510+
mod.matrix_transpose = matrix_transpose
511+
396512
# Handle functions with output unit defined by operation
397513

398514
# output_unit="sum":

0 commit comments

Comments
 (0)