@@ -298,6 +298,100 @@ def func(x, /, *args, func_str=func_str, **kwargs):
298
298
299
299
setattr (mod , func_str , func )
300
300
301
+ # strip_unit_input_output_ufuncs = ["isnan", "isinf", "isfinite", "signbit", "sign"]
302
+ # matching_input_bare_output_ufuncs = [
303
+ # "equal",
304
+ # "greater",
305
+ # "greater_equal",
306
+ # "less",
307
+ # "less_equal",
308
+ # "not_equal",
309
+ # ]
310
+ # matching_input_set_units_output_ufuncs = {"arctan2": "radian"}
311
+ # set_units_ufuncs = {
312
+ # "cumprod": ("", ""),
313
+ # "arccos": ("", "radian"),
314
+ # "arcsin": ("", "radian"),
315
+ # "arctan": ("", "radian"),
316
+ # "arccosh": ("", "radian"),
317
+ # "arcsinh": ("", "radian"),
318
+ # "arctanh": ("", "radian"),
319
+ # "exp": ("", ""),
320
+ # "expm1": ("", ""),
321
+ # "exp2": ("", ""),
322
+ # "log": ("", ""),
323
+ # "log10": ("", ""),
324
+ # "log1p": ("", ""),
325
+ # "log2": ("", ""),
326
+ # "sin": ("radian", ""),
327
+ # "cos": ("radian", ""),
328
+ # "tan": ("radian", ""),
329
+ # "sinh": ("radian", ""),
330
+ # "cosh": ("radian", ""),
331
+ # "tanh": ("radian", ""),
332
+ # "radians": ("degree", "radian"),
333
+ # "degrees": ("radian", "degree"),
334
+ # "deg2rad": ("degree", "radian"),
335
+ # "rad2deg": ("radian", "degree"),
336
+ # "logaddexp": ("", ""),
337
+ # "logaddexp2": ("", ""),
338
+ # }
339
+ # # TODO (#905 follow-up):
340
+ # # while this matches previous behavior, some of these have optional arguments
341
+ # that
342
+ # # should not be Quantities. This should be fixed, and tests using these optional
343
+ # # arguments should be added.
344
+ # matching_input_copy_units_output_ufuncs = [
345
+ # "compress",
346
+ # "conj",
347
+ # "conjugate",
348
+ # "copy",
349
+ # "diagonal",
350
+ # "max",
351
+ # "mean",
352
+ # "min",
353
+ # "ptp",
354
+ # "ravel",
355
+ # "repeat",
356
+ # "reshape",
357
+ # "round",
358
+ # "squeeze",
359
+ # "swapaxes",
360
+ # "take",
361
+ # "trace",
362
+ # "transpose",
363
+ # "roll",
364
+ # "ceil",
365
+ # "floor",
366
+ # "hypot",
367
+ # "rint",
368
+ # "copysign",
369
+ # "nextafter",
370
+ # "trunc",
371
+ # "absolute",
372
+ # "positive",
373
+ # "negative",
374
+ # "maximum",
375
+ # "minimum",
376
+ # "fabs",
377
+ # ]
378
+ # copy_units_output_ufuncs = ["ldexp", "fmod", "mod", "remainder"]
379
+ # op_units_output_ufuncs = {
380
+ # "var": "square",
381
+ # "multiply": "mul",
382
+ # "true_divide": "div",
383
+ # "divide": "div",
384
+ # "floor_divide": "div",
385
+ # "sqrt": "sqrt",
386
+ # "cbrt": "cbrt",
387
+ # "square": "square",
388
+ # "reciprocal": "reciprocal",
389
+ # "std": "sum",
390
+ # "sum": "sum",
391
+ # "cumsum": "sum",
392
+ # "matmul": "mul",
393
+ # }
394
+
301
395
elementwise_one_array = [
302
396
"abs" ,
303
397
"acos" ,
@@ -393,6 +487,28 @@ def fun(x1, x2, /, *args, func_str=func_str, **kwargs):
393
487
394
488
setattr (mod , func_str , fun )
395
489
490
+ def get_linalg_fun (func_str ):
491
+ def linalg_fun (x1 , x2 , / , ** kwargs ):
492
+ x1 = asarray (x1 )
493
+ x2 = asarray (x2 )
494
+ magnitude1 = xp .asarray (x1 .magnitude , copy = True )
495
+ magnitude2 = xp .asarray (x2 .magnitude , copy = True )
496
+
497
+ xp_func = getattr (xp , func_str )
498
+ magnitude = xp_func (magnitude1 , magnitude2 , ** kwargs )
499
+ return ArrayUnitQuantity (magnitude , x1 .units * x2 .units )
500
+
501
+ return linalg_fun
502
+
503
+ linalg_names = ["matmul" , "tensordot" , "vecdot" ]
504
+ for name in linalg_names :
505
+ setattr (mod , name , get_linalg_fun (name ))
506
+
507
+ def matrix_transpose (x ):
508
+ return x .mT
509
+
510
+ mod .matrix_transpose = matrix_transpose
511
+
396
512
# Handle functions with output unit defined by operation
397
513
398
514
# output_unit="sum":
0 commit comments