Skip to content

Commit 701fa91

Browse files
authored
Update tests. Add randn (#7)
* Update tests. Add randn * Fix import. Fix tests. Fix backend imports * Fix licence * Fix imports * Add newer backend imports * Fix eval
1 parent fecdb91 commit 701fa91

File tree

10 files changed

+437
-58
lines changed

10 files changed

+437
-58
lines changed

arrayfire/__init__.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -533,9 +533,9 @@
533533
trunc,
534534
)
535535

536-
__all__ += ["randu"]
536+
__all__ += ["randn", "randu"]
537537

538-
from arrayfire.library.random import randu
538+
from arrayfire.library.random import randn, randu
539539

540540
__all__ += [
541541
"fft",
@@ -591,23 +591,26 @@
591591

592592
from arrayfire.library.statistics import corrcoef, cov, mean, median, stdev, topk, var
593593

594-
__all__ += [
595-
"get_active_backend",
596-
"get_available_backends",
597-
"get_backend_count",
598-
"get_backend_id",
599-
"get_device_id",
600-
"set_backend",
601-
]
602-
603-
from arrayfire.library.unified_api_functions import (
604-
get_active_backend,
605-
get_available_backends,
606-
get_backend_count,
607-
get_backend_id,
608-
get_device_id,
609-
set_backend,
610-
)
594+
# TODO
595+
# Temp solution. Remove when arrayfire-binary-python-wrapper is finalized
596+
597+
# __all__ += [
598+
# "get_active_backend",
599+
# "get_available_backends",
600+
# "get_backend_count",
601+
# "get_backend_id",
602+
# "get_device_id",
603+
# "set_backend",
604+
# ]
605+
606+
# from arrayfire.library.unified_api_functions import (
607+
# get_active_backend,
608+
# get_available_backends,
609+
# get_backend_count,
610+
# get_backend_id,
611+
# get_device_id,
612+
# set_backend,
613+
# )
611614

612615
__all__ += [
613616
"accum",
@@ -656,3 +659,9 @@
656659
__all__ += ["cast"]
657660

658661
from arrayfire.library.utils import cast
662+
663+
# Backend
664+
665+
__all__ += ["set_backend", "get_backend", "BackendType"]
666+
667+
from arrayfire_wrapper import BackendType, get_backend, set_backend

arrayfire/library/array_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ def copy_array(array: Array, /) -> Array:
241241
def eval(*arrays: Array) -> None:
242242
if len(arrays) == 1:
243243
wrapper.eval(arrays[0].arr)
244+
return
244245

245246
arrs = [array.arr for array in arrays]
246247
wrapper.eval_multiple(len(arrays), *arrs)

arrayfire/library/linear_algebra.py

Lines changed: 130 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,43 @@ def dot(
2929
lhs: Array,
3030
rhs: Array,
3131
/,
32-
lhs_opts: MatProp = MatProp.NONE,
33-
rhs_opts: MatProp = MatProp.NONE,
3432
*,
3533
return_scalar: bool = False,
3634
) -> int | float | complex | Array:
35+
"""
36+
Calculates the dot product of two input arrays, with options to modify the operation
37+
on the input arrays and the possibility to return the result as a scalar.
38+
39+
Parameters
40+
----------
41+
lhs : Array
42+
A 1-dimensional, int of float Array instance, representing an array.
43+
44+
rhs : Array
45+
A 1-dimensional, int of float Array instance, representing another array.
46+
47+
return_scalar : bool, optional
48+
When set to True, the input arrays are flattened, and the output is a scalar value.
49+
Default is False.
50+
51+
Returns
52+
-------
53+
out : int | float | complex | Array
54+
The result of the dot product. Returns an Array unless `return_scalar` is True,
55+
in which case a scalar value (int, float, or complex) is returned based on the
56+
data type of the inputs.
57+
58+
Note
59+
-----
60+
- The data types of `lhs` and `rhs` should be the same.
61+
- Batch operations are not supported.
62+
- Modification options for `lhs` and `rhs` are currently disabled as function supports only `MatProp.NONE`.
63+
"""
64+
# TODO
65+
# Add support of lhs_opts and rhs_opts and return them as key arguments.
66+
lhs_opts: MatProp = MatProp.NONE
67+
rhs_opts: MatProp = MatProp.NONE
68+
3769
if return_scalar:
3870
return wrapper.dot_all(lhs.arr, rhs.arr, lhs_opts, rhs_opts)
3971

@@ -50,11 +82,105 @@ def gemm(
5082
alpha: int | float = 1.0,
5183
beta: int | float = 0.0,
5284
) -> Array:
85+
"""
86+
Performs BLAS general matrix multiplication (GEMM) on two Array instances.
87+
88+
The operation is defined as: C = alpha * op(lhs) * op(rhs) + beta * C, where op(X) is
89+
one of no operation, transpose, or Hermitian transpose, determined by lhs_opts and rhs_opts.
90+
91+
Parameters
92+
----------
93+
lhs : Array
94+
A 2-dimensional, real or complex array representing the left-hand side matrix.
95+
96+
rhs : Array
97+
A 2-dimensional, real or complex array representing the right-hand side matrix.
98+
99+
lhs_opts : MatProp, optional
100+
Operation to perform on `lhs` before multiplication. Default is MatProp.NONE. Options include:
101+
- MatProp.NONE: No operation.
102+
- MatProp.TRANS: Transpose.
103+
- MatProp.CTRANS: Hermitian transpose.
104+
105+
rhs_opts : MatProp, optional
106+
Operation to perform on `rhs` before multiplication. Default is MatProp.NONE. Options include:
107+
- MatProp.NONE: No operation.
108+
- MatProp.TRANS: Transpose.
109+
- MatProp.CTRANS: Hermitian transpose.
110+
111+
alpha : int | float, optional
112+
Scalar multiplier for the product of `lhs` and `rhs`. Default is 1.0.
113+
114+
beta : int | float, optional
115+
Scalar multiplier for the existing matrix C in the accumulation. Default is 0.0.
116+
117+
Returns
118+
-------
119+
Array
120+
The result of the matrix multiplication operation.
121+
122+
Note
123+
-----
124+
- The data types of `lhs` and `rhs` must be compatible.
125+
- Batch operations are not supported in this version.
126+
"""
53127
return cast(Array, wrapper.gemm(lhs.arr, rhs.arr, lhs_opts, rhs_opts, alpha, beta))
54128

55129

56130
@afarray_as_array
57131
def matmul(lhs: Array, rhs: Array, /, lhs_opts: MatProp = MatProp.NONE, rhs_opts: MatProp = MatProp.NONE) -> Array:
132+
"""
133+
Performs generalized matrix multiplication between two arrays with optional
134+
transposition or hermitian transposition operations on the input matrices.
135+
136+
Parameters
137+
----------
138+
lhs : af.Array
139+
A 2-dimensional, real or complex ArrayFire array representing the left-hand side matrix.
140+
141+
rhs : af.Array
142+
A 2-dimensional, real or complex ArrayFire array representing the right-hand side matrix.
143+
144+
lhs_opts : af.MATPROP, optional
145+
Operation to perform on the `lhs` matrix before multiplication. Defaults to af.MATPROP.NONE.
146+
Options include:
147+
- af.MATPROP.NONE: No operation.
148+
- af.MATPROP.TRANS: Transpose `lhs`.
149+
- af.MATPROP.CTRANS: Hermitian transpose (conjugate transpose) `lhs`.
150+
151+
rhs_opts : af.MATPROP, optional
152+
Operation to perform on the `rhs` matrix before multiplication. Defaults to af.MATPROP.NONE.
153+
Options include:
154+
- af.MATPROP.NONE: No operation.
155+
- af.MATPROP.TRANS: Transpose `rhs`.
156+
- af.MATPROP.CTRANS: Hermitian transpose (conjugate transpose) `rhs`.
157+
158+
Returns
159+
-------
160+
out : af.Array
161+
The result of the matrix multiplication. The output is a 2-dimensional ArrayFire array.
162+
163+
Notes
164+
-----
165+
- The data types of `lhs` and `rhs` must be the same.
166+
- Batch operations (multiplying multiple pairs of matrices at once) are not supported in this implementation.
167+
168+
Examples
169+
--------
170+
Basic matrix multiplication:
171+
172+
A = af.randu(5, 4, dtype=af.Dtype.f32)
173+
B = af.randu(4, 6, dtype=af.Dtype.f32)
174+
C = matmul(A, B)
175+
176+
Matrix multiplication with the left-hand side transposed:
177+
178+
C = matmul(A, B, lhs_opts=af.MATPROP.TRANS)
179+
180+
Matrix multiplication with both matrices transposed:
181+
182+
C = matmul(A, B, lhs_opts=af.MATPROP.TRANS, rhs_opts=af.MATPROP.TRANS)
183+
"""
58184
return cast(Array, wrapper.matmul(lhs.arr, rhs.arr, lhs_opts, rhs_opts))
59185

60186

@@ -121,5 +247,5 @@ def solve(a: Array, b: Array, /, *, options: MatProp = MatProp.NONE, pivot: None
121247
return cast(Array, wrapper.solve(a.arr, b.arr, options))
122248

123249

124-
# TODO
125-
# Add Sparse functions? #good_first_issue
250+
# TODO #good_first_issue
251+
# Add Sparse functions

arrayfire/library/random.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,52 @@ def from_engine(cls, engine: wrapper.AFRandomEngineHandle) -> RandomEngine:
131131
return instance
132132

133133

134+
@afarray_as_array
135+
def randn(shape: tuple[int, ...], /, *, dtype: Dtype = float32, engine: RandomEngine | None = None) -> Array:
136+
"""
137+
Create a multi-dimensional array containing values sampled from a normal distribution with mean 0
138+
and standard deviation of 1.
139+
140+
Parameters
141+
----------
142+
shape : tuple[int, ...]
143+
The shape of the resulting array. Must be a tuple with at least one element, e.g., shape=(3,).
144+
145+
dtype : Dtype, optional, default: `float32`
146+
The data type of the array elements.
147+
148+
engine : RandomEngine | None, optional
149+
The random number generator engine to be used. If None, uses a default engine created by ArrayFire.
150+
151+
Returns
152+
-------
153+
Array
154+
A multi-dimensional array whose elements are sampled from a normal distribution. The dimensions of the array
155+
are determined by `shape`:
156+
- If shape is (x,), the output is a 1D array of size (x,).
157+
- If shape is (x, y), the output is a 2D array of size (x, y).
158+
- If shape is (x, y, z), the output is a 3D array of size (x, y, z).
159+
- For more dimensions, the output shape corresponds directly to the specified `shape` tuple.
160+
161+
Notes
162+
-----
163+
The function supports creating arrays of up to N dimensions, where N is determined by the length
164+
of the `shape` tuple.
165+
166+
Raises
167+
------
168+
ValueError
169+
If `shape` is not a tuple or has less than one value.
170+
"""
171+
if not isinstance(shape, tuple) or not shape:
172+
raise ValueError("Argument shape must be a tuple with at least 1 value.")
173+
174+
if engine is None:
175+
return cast(Array, wrapper.randn(shape, dtype))
176+
177+
return cast(Array, wrapper.random_normal(shape, dtype, engine.get_engine()))
178+
179+
134180
@afarray_as_array
135181
def randu(shape: tuple[int, ...], /, *, dtype: Dtype = float32, engine: RandomEngine | None = None) -> Array:
136182
"""
Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
1-
__all__ = [
2-
"get_active_backend",
3-
"get_available_backends",
4-
"get_backend_count",
5-
"get_backend_id",
6-
"get_device_id",
7-
"set_backend",
8-
]
1+
# TODO
2+
# Temp solution. Remove when arrayfire-binary-python-wrapper is finalized
93

10-
from arrayfire_wrapper.lib import get_active_backend, get_available_backends, get_backend_count
11-
from arrayfire_wrapper.lib import get_backend_id as wrapped_get_backend_id
12-
from arrayfire_wrapper.lib import get_device_id as wrapped_get_device_id
13-
from arrayfire_wrapper.lib import set_backend
4+
# __all__ = [
5+
# "get_active_backend",
6+
# "get_available_backends",
7+
# "get_backend_count",
8+
# "get_backend_id",
9+
# "get_device_id",
10+
# "set_backend",
11+
# ]
1412

15-
from arrayfire import Array
13+
# from arrayfire_wrapper.lib import get_active_backend, get_available_backends, get_backend_count
14+
# from arrayfire_wrapper.lib import get_backend_id as wrapped_get_backend_id
15+
# from arrayfire_wrapper.lib import get_device_id as wrapped_get_device_id
16+
# from arrayfire_wrapper.lib import set_backend
1617

18+
# from arrayfire import Array
1719

18-
def get_backend_id(array: Array) -> int:
19-
return wrapped_get_backend_id(array.arr)
2020

21+
# def get_backend_id(array: Array) -> int:
22+
# return wrapped_get_backend_id(array.arr)
2123

22-
def get_device_id(array: Array) -> int:
23-
return wrapped_get_device_id(array.arr)
24+
25+
# def get_device_id(array: Array) -> int:
26+
# return wrapped_get_device_id(array.arr)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
arrayfire-python-wrapper==0.5.0+af3.9.0
1+
arrayfire-binary-python-wrapper==0.6.0+af3.9.0

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def fix_url_dependencies(req: str) -> str:
6868
name="arrayfire",
6969
version=VERSION["VERSION"],
7070
description="ArrayFire Python Wrapper",
71-
licence="BSD",
71+
license="BSD",
7272
long_description=(ABS_PATH / "README.md").open("r").read(),
7373
long_description_content_type="text/markdown",
7474
author="ArrayFire",

tests/_helpers.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
1+
import arrayfire as af
2+
3+
14
def round_to(list_: list[int | float | complex | bool], symbols: int = 3) -> list[int | float]:
25
# HACK replace for e.g. abs(x1-x2) < 1e-6 ~ https://davidamos.dev/the-right-way-to-compare-floats-in-python/
36
return [round(x, symbols) for x in list_]
7+
8+
9+
def create_from_2d_nested(x1: float, x2: float, x3: float, x4: float, dtype: af.Dtype = af.float32) -> af.Array:
10+
array = af.randu((2, 2), dtype=dtype)
11+
array[0, 0] = x1
12+
array[0, 1] = x2
13+
array[1, 0] = x3
14+
array[1, 1] = x4
15+
return array

0 commit comments

Comments
 (0)