|
11 | 11 | import types
|
12 | 12 | from typing import Generic
|
13 | 13 |
|
| 14 | +from array_api_compat import is_array_api_obj |
14 | 15 | from pint import Quantity
|
15 | 16 | from pint.facets.plain import MagnitudeT, PlainQuantity
|
| 17 | +from pint.util import iterable, sized |
16 | 18 |
|
17 | 19 | __version__ = "0.0.1.dev0"
|
18 | 20 | __all__ = ["__version__", "pint_namespace"]
|
@@ -211,6 +213,24 @@ def asarray(obj, /, *, units=None, dtype=None, device=None, copy=None):
|
211 | 213 |
|
212 | 214 | mod.asarray = asarray
|
213 | 215 |
|
| 216 | + creation_functions = [ |
| 217 | + "arange", |
| 218 | + "empty", |
| 219 | + "eye", |
| 220 | + "from_dlpack", |
| 221 | + "full", |
| 222 | + "linspace", |
| 223 | + "ones", |
| 224 | + "zeros", |
| 225 | + ] |
| 226 | + for func_str in creation_functions: |
| 227 | + |
| 228 | + def fun(*args, func_str=func_str, units=None, **kwargs): |
| 229 | + magnitude = getattr(xp, func_str)(*args, **kwargs) |
| 230 | + return ArrayUnitQuantity(magnitude, units) |
| 231 | + |
| 232 | + setattr(mod, func_str, fun) |
| 233 | + |
214 | 234 | ## Data Type Functions and Data Types ##
|
215 | 235 | dtype_fun_names = ["can_cast", "finfo", "iinfo", "isdtype"]
|
216 | 236 | dtype_names = [
|
@@ -280,6 +300,140 @@ def func(x, /, *args, func_str=func_str, **kwargs):
|
280 | 300 |
|
281 | 301 | setattr(mod, func_str, func)
|
282 | 302 |
|
| 303 | + elementwise_one_array = [ |
| 304 | + "abs", |
| 305 | + "acos", |
| 306 | + "acosh", |
| 307 | + "asin", |
| 308 | + "asinh", |
| 309 | + "atan", |
| 310 | + "atanh", |
| 311 | + "bitwise_invert", |
| 312 | + "ceil", |
| 313 | + "conj", |
| 314 | + "cos", |
| 315 | + "cosh", |
| 316 | + "exp", |
| 317 | + "expm1", |
| 318 | + "floor", |
| 319 | + "imag", |
| 320 | + "isfinite", |
| 321 | + "isinf", |
| 322 | + "isnan", |
| 323 | + "log", |
| 324 | + "log1p", |
| 325 | + "log2", |
| 326 | + "log10", |
| 327 | + "logical_not", |
| 328 | + "negative", |
| 329 | + "positive", |
| 330 | + "real", |
| 331 | + "round", |
| 332 | + "sign", |
| 333 | + "signbit", |
| 334 | + "sin", |
| 335 | + "sinh", |
| 336 | + "square", |
| 337 | + "sqrt", |
| 338 | + "tan", |
| 339 | + "tanh", |
| 340 | + "trunc", |
| 341 | + ] |
| 342 | + for func_str in elementwise_one_array: |
| 343 | + |
| 344 | + def fun(x, /, *args, func_str=func_str, **kwargs): |
| 345 | + x = asarray(x) |
| 346 | + magnitude = xp.asarray(x.magnitude, copy=True) |
| 347 | + magnitude = getattr(xp, func_str)(x, *args, **kwargs) |
| 348 | + return ArrayUnitQuantity(magnitude, x.units) |
| 349 | + |
| 350 | + setattr(mod, func_str, fun) |
| 351 | + |
| 352 | + def _is_quantity(obj): |
| 353 | + """Test for _units and _magnitude attrs. |
| 354 | +
|
| 355 | + This is done in place of isinstance(Quantity, arg), |
| 356 | + which would cause a circular import. |
| 357 | +
|
| 358 | + Parameters |
| 359 | + ---------- |
| 360 | + obj : Object |
| 361 | +
|
| 362 | + Returns |
| 363 | + ------- |
| 364 | + bool |
| 365 | + """ |
| 366 | + return hasattr(obj, "_units") and hasattr(obj, "_magnitude") |
| 367 | + |
| 368 | + def _is_sequence_with_quantity_elements(obj): |
| 369 | + """Test for sequences of quantities. |
| 370 | +
|
| 371 | + Parameters |
| 372 | + ---------- |
| 373 | + obj : object |
| 374 | +
|
| 375 | + Returns |
| 376 | + ------- |
| 377 | + True if obj is a sequence and at least one element is a Quantity; |
| 378 | + False otherwise |
| 379 | + """ |
| 380 | + if is_array_api_obj(obj) and not obj.dtype.hasobject: |
| 381 | + # If obj is an array, avoid looping on all elements |
| 382 | + # if dtype does not have objects |
| 383 | + return False |
| 384 | + return ( |
| 385 | + iterable(obj) |
| 386 | + and sized(obj) |
| 387 | + and not isinstance(obj, str) |
| 388 | + and any(_is_quantity(item) for item in obj) |
| 389 | + ) |
| 390 | + |
| 391 | + elementwise_two_arrays = [ |
| 392 | + "add", |
| 393 | + "atan2", |
| 394 | + "bitwise_and", |
| 395 | + "bitwise_left_shift", |
| 396 | + "bitwise_or", |
| 397 | + "bitwise_right_shift", |
| 398 | + "bitwise_xor", |
| 399 | + "copysign", |
| 400 | + "divide", |
| 401 | + "equal", |
| 402 | + "floor_divide", |
| 403 | + "greater", |
| 404 | + "greater_equal", |
| 405 | + "hypot", |
| 406 | + "less", |
| 407 | + "less_equal", |
| 408 | + "logaddexp", |
| 409 | + "logical_and", |
| 410 | + "logical_or", |
| 411 | + "logical_xor", |
| 412 | + "maximum", |
| 413 | + "minimum", |
| 414 | + "multiply", |
| 415 | + "not_equal", |
| 416 | + "pow", |
| 417 | + "remainder", |
| 418 | + "subtract", |
| 419 | + ] |
| 420 | + for func_str in elementwise_two_arrays: |
| 421 | + |
| 422 | + def fun(x1, x2, /, *args, func_str=func_str, **kwargs): |
| 423 | + x1 = asarray(x1) |
| 424 | + x2 = asarray(x2) |
| 425 | + |
| 426 | + units = x1.units |
| 427 | + |
| 428 | + x1_magnitude = xp.asarray(x1.magnitude, copy=True) |
| 429 | + x2_magnitude = x2.m_as(units) |
| 430 | + |
| 431 | + xp_func = getattr(xp, func_str) |
| 432 | + magnitude = xp_func(x1_magnitude, x2_magnitude, *args, **kwargs) |
| 433 | + return ArrayUnitQuantity(magnitude, units) |
| 434 | + |
| 435 | + setattr(mod, func_str, fun) |
| 436 | + |
283 | 437 | # Handle functions with output unit defined by operation
|
284 | 438 |
|
285 | 439 | # output_unit="sum":
|
|
0 commit comments