Skip to content

Commit 54e4594

Browse files
authored
Add overloads. Fix typings (#16)
* Add overloads. Fix examples. Add missing reqs
1 parent 65b5ccd commit 54e4594

24 files changed

+701
-317
lines changed

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ endif
1414
version :
1515
@python -c 'from arrayfire.version import VERSION; print(f"ArrayFire Python v{VERSION}")'
1616

17+
.PHONY : build
18+
build :
19+
@python -m build
20+
1721
# Dev
1822

1923
.PHONY : pre-commit

arrayfire/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,13 +562,23 @@
562562
"approx1_uniform",
563563
"approx2",
564564
"approx2_uniform",
565+
"convolve1",
566+
"convolve2",
567+
"convolve2_nn",
568+
"convolve2_separable",
569+
"convolve3",
565570
]
566571

567572
from arrayfire.library.signal_processing import (
568573
approx1,
569574
approx1_uniform,
570575
approx2,
571576
approx2_uniform,
577+
convolve1,
578+
convolve2,
579+
convolve2_nn,
580+
convolve2_separable,
581+
convolve3,
572582
fft,
573583
fft2,
574584
fft2_c2r,

arrayfire/array_api/_sorting_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def argsort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bo
3535
if axis == -1:
3636
axis = x.ndim - 1
3737

38-
_, indices = af.sort(x._array, axis=axis, is_ascending=not descending, is_index_array=True) # type: ignore[misc]
38+
_, indices = af.sort(x._array, axis=axis, is_ascending=not descending, is_index_array=True)
3939
return Array._new(indices)
4040

4141

arrayfire/array_object.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -549,41 +549,41 @@ def __ne__(self, other: int | float | bool | Array, /) -> Array: # type: ignore
549549

550550
# Reflected Arithmetic Operators
551551

552-
def __radd__(self, other: Array, /) -> Array:
552+
def __radd__(self, other: int | float | Array, /) -> Array:
553553
"""
554554
Return other + self.
555555
"""
556556
return process_c_function(other, self, wrapper.add)
557557

558-
def __rsub__(self, other: Array, /) -> Array:
558+
def __rsub__(self, other: int | float | Array, /) -> Array:
559559
"""
560560
Return other - self.
561561
"""
562562
return process_c_function(other, self, wrapper.sub)
563563

564-
def __rmul__(self, other: Array, /) -> Array:
564+
def __rmul__(self, other: int | float | Array, /) -> Array:
565565
"""
566566
Return other * self.
567567
"""
568568
return process_c_function(other, self, wrapper.mul)
569569

570-
def __rtruediv__(self, other: Array, /) -> Array:
570+
def __rtruediv__(self, other: int | float | Array, /) -> Array:
571571
"""
572572
Return other / self.
573573
"""
574574
return process_c_function(other, self, wrapper.div)
575575

576-
def __rfloordiv__(self, other: Array, /) -> Array:
576+
def __rfloordiv__(self, other: int | float | Array, /) -> Array:
577577
# TODO
578578
return NotImplemented
579579

580-
def __rmod__(self, other: Array, /) -> Array:
580+
def __rmod__(self, other: int | float | Array, /) -> Array:
581581
"""
582582
Return other % self.
583583
"""
584584
return process_c_function(other, self, wrapper.mod)
585585

586-
def __rpow__(self, other: Array, /) -> Array:
586+
def __rpow__(self, other: int | float | Array, /) -> Array:
587587
"""
588588
Return other ** self.
589589
"""
@@ -597,31 +597,31 @@ def __rmatmul__(self, other: Array, /) -> Array:
597597

598598
# Reflected Bitwise Operators
599599

600-
def __rand__(self, other: Array, /) -> Array:
600+
def __rand__(self, other: int | bool | Array, /) -> Array:
601601
"""
602602
Return other & self.
603603
"""
604604
return process_c_function(other, self, wrapper.bitand)
605605

606-
def __ror__(self, other: Array, /) -> Array:
606+
def __ror__(self, other: int | bool | Array, /) -> Array:
607607
"""
608608
Return other | self.
609609
"""
610610
return process_c_function(other, self, wrapper.bitor)
611611

612-
def __rxor__(self, other: Array, /) -> Array:
612+
def __rxor__(self, other: int | bool | Array, /) -> Array:
613613
"""
614614
Return other ^ self.
615615
"""
616616
return process_c_function(other, self, wrapper.bitxor)
617617

618-
def __rlshift__(self, other: Array, /) -> Array:
618+
def __rlshift__(self, other: int | Array, /) -> Array:
619619
"""
620620
Return other << self.
621621
"""
622622
return process_c_function(other, self, wrapper.bitshiftl)
623623

624-
def __rrshift__(self, other: Array, /) -> Array:
624+
def __rrshift__(self, other: int | Array, /) -> Array:
625625
"""
626626
Return other >> self.
627627
"""
@@ -1126,11 +1126,11 @@ def process_c_function(lhs: int | float | Array, rhs: int | float | Array, c_fun
11261126
lhs_array = lhs.arr
11271127
rhs_array = rhs.arr
11281128

1129-
elif isinstance(lhs, Array) and isinstance(rhs, (int, float)):
1129+
elif isinstance(lhs, Array) and isinstance(rhs, int | float):
11301130
lhs_array = lhs.arr
11311131
rhs_array = wrapper.create_constant_array(rhs, lhs.shape, lhs.dtype)
11321132

1133-
elif isinstance(lhs, (int, float)) and isinstance(rhs, Array):
1133+
elif isinstance(lhs, int | float) and isinstance(rhs, Array):
11341134
lhs_array = wrapper.create_constant_array(lhs, rhs.shape, rhs.dtype)
11351135
rhs_array = rhs.arr
11361136

arrayfire/library/device.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,25 @@
4747
set_device,
4848
set_kernel_cache_directory,
4949
set_mem_step_size,
50-
sync,
5150
)
51+
from arrayfire_wrapper.lib import sync as wrapper_sync
52+
53+
54+
def sync(device_id: int | None = None) -> None:
55+
"""
56+
Blocks until all the functions on the specified device have completed execution.
57+
58+
This function is used to synchronize the program execution with the operations
59+
being carried out on a GPU or other computation device, ensuring that all
60+
previously submitted operations are complete before the program proceeds.
61+
62+
Parameters
63+
----------
64+
device_id : int | None, optional
65+
The ID of the device on which to wait for all operations to complete.
66+
If None is provided, the current active device is used. Default is None.
67+
"""
68+
if device_id is None:
69+
device_id = get_device()
70+
71+
wrapper_sync(device_id)

arrayfire/library/linear_algebra.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"solve",
1616
]
1717

18-
from typing import cast
18+
from typing import Literal, cast, overload
1919

2020
import arrayfire_wrapper.lib as wrapper
2121
from arrayfire_wrapper.lib import is_lapack_available
@@ -24,6 +24,17 @@
2424
from arrayfire.array_object import afarray_as_array
2525
from arrayfire.library.constants import MatProp, Norm
2626

27+
# TODO
28+
# Add missing documentation
29+
30+
31+
@overload
32+
def dot(lhs: Array, rhs: Array, /, *, return_scalar: Literal[True]) -> int | float | complex: ...
33+
34+
35+
@overload
36+
def dot(lhs: Array, rhs: Array, /, *, return_scalar: Literal[False] = False) -> Array: ...
37+
2738

2839
def dot(
2940
lhs: Array,
@@ -184,6 +195,14 @@ def matmul(lhs: Array, rhs: Array, /, lhs_opts: MatProp = MatProp.NONE, rhs_opts
184195
return cast(Array, wrapper.matmul(lhs.arr, rhs.arr, lhs_opts, rhs_opts))
185196

186197

198+
@overload
199+
def cholesky(array: Array, /, is_upper: bool = True, *, inplace: Literal[True]) -> int: ...
200+
201+
202+
@overload
203+
def cholesky(array: Array, /, is_upper: bool = True, *, inplace: Literal[False] = False) -> tuple[Array, int]: ...
204+
205+
187206
def cholesky(array: Array, /, is_upper: bool = True, *, inplace: bool = False) -> int | tuple[Array, int]:
188207
if inplace:
189208
return wrapper.cholesky_inplace(array.arr, is_upper)
@@ -192,6 +211,14 @@ def cholesky(array: Array, /, is_upper: bool = True, *, inplace: bool = False) -
192211
return Array.from_afarray(matrix), info
193212

194213

214+
@overload
215+
def lu(array: Array, /, *, inplace: Literal[True], is_lapack_pivot: bool = True) -> Array: ...
216+
217+
218+
@overload
219+
def lu(array: Array, /, *, inplace: Literal[False] = False, is_lapack_pivot: bool = True) -> tuple[Array, ...]: ...
220+
221+
195222
def lu(array: Array, /, *, inplace: bool = False, is_lapack_pivot: bool = True) -> Array | tuple[Array, ...]:
196223
if inplace:
197224
return Array.from_afarray(wrapper.lu_inplace(array.arr, is_lapack_pivot))
@@ -200,6 +227,14 @@ def lu(array: Array, /, *, inplace: bool = False, is_lapack_pivot: bool = True)
200227
return Array.from_afarray(lower), Array.from_afarray(upper), Array.from_afarray(pivot)
201228

202229

230+
@overload
231+
def qr(array: Array, /, *, inplace: Literal[True]) -> Array: ...
232+
233+
234+
@overload
235+
def qr(array: Array, /, *, inplace: Literal[False] = False) -> tuple[Array, ...]: ...
236+
237+
203238
def qr(array: Array, /, *, inplace: bool = False) -> Array | tuple[Array, ...]:
204239
if inplace:
205240
return Array.from_afarray(wrapper.qr_inplace(array.arr))
@@ -247,5 +282,5 @@ def solve(a: Array, b: Array, /, *, options: MatProp = MatProp.NONE, pivot: None
247282
return cast(Array, wrapper.solve(a.arr, b.arr, options))
248283

249284

250-
# TODO #good_first_issue
251-
# Add Sparse functions
285+
# TODO
286+
# Create issues as #good_first_issue: add Sparse functions

arrayfire/library/signal_processing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@
2121
"approx1_uniform",
2222
"approx2",
2323
"approx2_uniform",
24+
"convolve1",
25+
"convolve2",
26+
"convolve2_nn",
27+
"convolve2_separable",
28+
"convolve3",
2429
]
2530

2631
from typing import cast

arrayfire/library/statistics.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
__all__ = ["corrcoef", "cov", "mean", "median", "stdev", "topk", "var"]
22

3-
from typing import cast
3+
from typing import cast, overload
44

55
import arrayfire_wrapper.lib as wrapper
66

77
from arrayfire import Array
88
from arrayfire.array_object import afarray_as_array
99
from arrayfire.library.constants import TopK, VarianceBias
1010

11+
# TODO
12+
# Add missing documentation
13+
1114

1215
def corrcoef(x: Array, y: Array, /) -> int | float | complex:
1316
return wrapper.corrcoef(x.arr, y.arr)
@@ -18,6 +21,22 @@ def cov(x: Array, y: Array, /, *, bias: VarianceBias = VarianceBias.DEFAULT) ->
1821
return cast(Array, wrapper.cov(x.arr, y.arr, bias))
1922

2023

24+
@overload
25+
def mean(x: Array, /, axis: None = None, *, weights: None = None) -> int | float | complex: ...
26+
27+
28+
@overload
29+
def mean(x: Array, /, axis: int, *, weights: None = None) -> Array: ...
30+
31+
32+
@overload
33+
def mean(x: Array, /, axis: None, *, weights: Array) -> int | float | complex: ...
34+
35+
36+
@overload
37+
def mean(x: Array, /, axis: int, *, weights: Array) -> Array: ...
38+
39+
2140
def mean(x: Array, /, axis: None | int = None, *, weights: None | Array = None) -> int | float | complex | Array:
2241
if weights:
2342
if axis is None:
@@ -31,13 +50,29 @@ def mean(x: Array, /, axis: None | int = None, *, weights: None | Array = None)
3150
return Array.from_afarray(wrapper.mean(x.arr, axis))
3251

3352

53+
@overload
54+
def median(x: Array, /, axis: None = None) -> int | float | complex: ...
55+
56+
57+
@overload
58+
def median(x: Array, /, axis: int) -> Array: ...
59+
60+
3461
def median(x: Array, /, axis: None | int = None) -> int | float | complex | Array:
3562
if axis is None:
3663
return wrapper.median_all(x.arr)
3764

3865
return Array.from_afarray(wrapper.median(x.arr, axis))
3966

4067

68+
@overload
69+
def stdev(x: Array, /, axis: None = None, *, bias: VarianceBias = VarianceBias.DEFAULT) -> int | float | complex: ...
70+
71+
72+
@overload
73+
def stdev(x: Array, /, axis: int, *, bias: VarianceBias = VarianceBias.DEFAULT) -> int | float | complex: ...
74+
75+
4176
def stdev(
4277
x: Array, /, axis: None | int = None, *, bias: VarianceBias = VarianceBias.DEFAULT
4378
) -> int | float | complex | Array:
@@ -52,6 +87,26 @@ def topk(x: Array, k: int, /, *, axis: int = 0, order: TopK = TopK.DEFAULT) -> t
5287
return Array.from_afarray(values), Array.from_afarray(indices)
5388

5489

90+
@overload
91+
def var(
92+
x: Array, /, axis: None = None, *, weights: None = None, bias: VarianceBias = VarianceBias.DEFAULT
93+
) -> int | float | complex: ...
94+
95+
96+
@overload
97+
def var(x: Array, /, axis: int, *, weights: None = None, bias: VarianceBias = VarianceBias.DEFAULT) -> Array: ...
98+
99+
100+
@overload
101+
def var(
102+
x: Array, /, axis: None, *, weights: Array, bias: VarianceBias = VarianceBias.DEFAULT
103+
) -> int | float | complex: ...
104+
105+
106+
@overload
107+
def var(x: Array, /, axis: int, *, weights: Array, bias: VarianceBias = VarianceBias.DEFAULT) -> Array: ...
108+
109+
55110
def var(
56111
x: Array,
57112
/,

0 commit comments

Comments
 (0)