Skip to content

Commit e2f8c14

Browse files
authored
Merge pull request #2 from sdpython/array_api
Implements the first step for the Array API, test with scikit-learn
2 parents d875d0d + 3f64e9c commit e2f8c14

12 files changed

+556
-111
lines changed

_unittests/ut__main/test_documentation_examples.py

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@
44
import importlib
55
import subprocess
66
import time
7+
from onnx_array_api import __file__ as onnx_array_api_file
78
from onnx_array_api.ext_test_case import ExtTestCase
89

10+
VERBOSE = 0
11+
ROOT = os.path.realpath(os.path.abspath(os.path.join(onnx_array_api_file, "..", "..")))
12+
913

1014
def import_source(module_file_path, module_name):
1115
if not os.path.exists(module_file_path):
@@ -20,43 +24,58 @@ def import_source(module_file_path, module_name):
2024

2125

2226
class TestDocumentationExamples(ExtTestCase):
23-
def test_documentation_examples(self):
27+
def run_test(self, fold: str, name: str, verbose=0) -> int:
    """
    Runs one documentation example and tells whether it was executed.

    The example is first imported in-process with ``import_source``.
    If that raises :class:`FileNotFoundError`, the file is executed
    as a script in a subprocess instead.

    :param fold: folder containing the example
    :param name: file name of the example, extension included
    :param verbose: if truthy, prints the timing and the reason
        an example was skipped
    :return: 1 if the example ran, 0 if it was skipped because
        *dot* is missing
    :raises AssertionError: if the subprocess writes a traceback
        on the standard error
    """
    # The subprocess inherits the environment: the sources must be
    # reachable through PYTHONPATH.
    ppath = os.environ.get("PYTHONPATH", "")
    if not ppath:
        os.environ["PYTHONPATH"] = ROOT
    elif ROOT not in ppath.split(os.pathsep):
        # Whole entries are compared: a substring test would wrongly
        # accept any longer path which merely contains ROOT.
        os.environ["PYTHONPATH"] = ppath + os.pathsep + ROOT

    perf = time.perf_counter()
    try:
        mod = import_source(fold, os.path.splitext(name)[0])
        assert mod is not None
    except FileNotFoundError:
        # try another way
        cmds = [sys.executable, "-u", os.path.join(fold, name)]
        proc = subprocess.run(
            cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False
        )
        st = proc.stderr.decode("ascii", errors="ignore")
        if "Traceback" in st:
            if '"dot" not found in path.' in st:
                # dot not installed, this part
                # is tested in onnx framework
                if verbose:
                    print(f"failed: {name!r} due to missing dot.")
                return 0
            raise AssertionError(
                f"Example '{name}' (cmd: {cmds} - "
                f"exec_prefix='{sys.exec_prefix}') "
                f"failed due to\n{st}"
            )
    dt = time.perf_counter() - perf
    if verbose:
        print(f"{dt:.3f}: run {name!r}")
    return 1
61+
62+
@classmethod
63+
def add_test_methods(cls):
2464
this = os.path.abspath(os.path.dirname(__file__))
2565
fold = os.path.normpath(os.path.join(this, "..", "..", "_doc", "examples"))
2666
found = os.listdir(fold)
27-
tested = 0
2867
for name in found:
2968
if name.startswith("plot_") and name.endswith(".py"):
30-
perf = time.perf_counter()
31-
try:
32-
mod = import_source(fold, os.path.splitext(name)[0])
33-
assert mod is not None
34-
except FileNotFoundError:
35-
# try another way
36-
cmds = [sys.executable, "-u", os.path.join(fold, name)]
37-
p = subprocess.Popen(
38-
cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE
39-
)
40-
res = p.communicate()
41-
out, err = res
42-
st = err.decode("ascii", errors="ignore")
43-
if len(st) > 0 and "Traceback" in st:
44-
if '"dot" not found in path.' in st:
45-
# dot not installed, this part
46-
# is tested in onnx framework
47-
print(f"failed: {name!r} due to missing dot.")
48-
continue
49-
raise AssertionError(
50-
"Example '{}' (cmd: {} - exec_prefix='{}') "
51-
"failed due to\n{}"
52-
"".format(name, cmds, sys.exec_prefix, st)
53-
)
54-
dt = time.perf_counter() - perf
55-
print(f"{dt:.3f}: run {name!r}")
56-
tested += 1
57-
if tested == 0:
58-
raise AssertionError("No example was tested.")
69+
short_name = os.path.split(os.path.splitext(name)[0])[-1]
70+
71+
def _test_(self, name=name):
72+
res = self.run_test(fold, name, verbose=VERBOSE)
73+
self.assertTrue(res)
74+
75+
setattr(cls, f"test_{short_name}", _test_)
76+
5977

78+
TestDocumentationExamples.add_test_methods()
6079

6180
if __name__ == "__main__":
62-
unittest.main()
81+
unittest.main(verbosity=2)

_unittests/ut_npx/test_npx.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from onnx_array_api.npx.npx_functions import sinh as sinh_inline
7272
from onnx_array_api.npx.npx_functions import sqrt as sqrt_inline
7373
from onnx_array_api.npx.npx_functions import squeeze as squeeze_inline
74+
from onnx_array_api.npx.npx_functions import take as take_inline
7475
from onnx_array_api.npx.npx_functions import tan as tan_inline
7576
from onnx_array_api.npx.npx_functions import tanh as tanh_inline
7677
from onnx_array_api.npx.npx_functions import topk as topk_inline
@@ -2439,7 +2440,19 @@ def myloss(x, y):
24392440
res_eager = eager_myloss(x, y)
24402441
assert_allclose(res_jit, res_eager)
24412442

2443+
def test_take(self):
    """Checks ``take`` against :func:`numpy.take` along axis 1."""
    # The constraints below declare A as Float64 and B as Int64:
    # the inputs are created with these exact types, the default
    # integer type of numpy depends on the platform.
    data = np.random.randn(3, 3).astype(np.float64)
    indices = np.array([[0, 2]], dtype=np.int64)
    y = np.take(data, indices, axis=1)

    f = take_inline(Input("A"), Input("B"), axis=1)
    self.assertIsInstance(f, Var)
    onx = f.to_onnx(constraints={"A": Float64[None], "B": Int64[None]})
    ref = ReferenceEvaluator(onx)
    got = ref.run(None, {"A": data, "B": indices})
    self.assertEqualArray(y, got[0])
2454+
24422455

24432456
if __name__ == "__main__":
2444-
TestNpx().test_eager_cst_index()
2457+
TestNpx().test_take()
24452458
unittest.main(verbosity=2)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import unittest
2+
import numpy as np
3+
from onnx.defs import onnx_opset_version
4+
from sklearn import config_context
5+
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
6+
from onnx_array_api.ext_test_case import ExtTestCase
7+
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
8+
9+
10+
DEFAULT_OPSET = onnx_opset_version()
11+
12+
13+
def take(self, X, indices, *, axis):
    """
    Replacement for the method ``take`` of the array API wrapper.

    It calls ``take`` from the wrapped namespace and converts
    the result with ``asarray`` from the same namespace.
    """
    # Overwriting method take as it is using iterators.
    # When array_api supports `take` we can use this directly
    # https://github.com/data-apis/array-api/issues/177
    namespace = self._namespace
    return namespace.asarray(namespace.take(X, indices, axis=axis))
19+
20+
21+
class TestSklearnArrayAPI(ExtTestCase):
    """Checks a scikit-learn estimator fed with an ``EagerNumpyTensor``."""

    def test_sklearn_array_api_linear_discriminant(self):
        """
        Compares the predictions of ``LinearDiscriminantAnalysis``
        on a numpy array with the predictions obtained on the same data
        wrapped into an ``EagerNumpyTensor`` while
        ``array_api_dispatch`` is enabled.
        """
        from sklearn.utils._array_api import _ArrayAPIWrapper

        # The method of scikit-learn is replaced by the function ``take``
        # defined in this file. NOTE(review): the patch is never undone,
        # confirm that no other test relies on the original method.
        _ArrayAPIWrapper.take = take
        X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
        y = np.array([1, 1, 1, 2, 2, 2])
        ana = LinearDiscriminantAnalysis()
        ana.fit(X, y)
        expected = ana.predict(X)

        new_x = EagerNumpyTensor(X)
        self.assertStartsWith("EagerNumpyTensor(array([[", repr(new_x))
        with config_context(array_api_dispatch=True):
            got = ana.predict(new_x)
            self.assertEqualArray(expected, got.numpy())
38+
39+
40+
if __name__ == "__main__":
41+
# import logging
42+
# logging.basicConfig(level=logging.DEBUG)
43+
unittest.main(verbosity=2)

onnx_array_api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
(Numpy) Array API for ONNX.
44
"""
55

6-
__version__ = "0.1.1"
6+
__version__ = "0.1.2"
77
__author__ = "Xavier Dupré"

onnx_array_api/npx/npx_array_api.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ class ArrayApi:
1010
List of methods supported by a tensor.
1111
"""
1212

13+
def __array_namespace__(self, *, api_version=None):
    """
    Returns the module holding all the available functions.

    :param api_version: accepted so that the signature matches
        the one defined by the Array API standard; the value is not
        used, the same module is returned whatever it is
    :return: module ``onnx_array_api.npx.npx_functions``
    """
    # Imported here and not at the top of the file.
    # NOTE(review): presumably to avoid a circular import - confirm.
    from onnx_array_api.npx import npx_functions

    return npx_functions
20+
1321
def generic_method(self, method_name, *args: Any, **kwargs: Any) -> Any:
1422
raise NotImplementedError(
1523
f"Method {method_name!r} must be overwritten "

onnx_array_api/npx/npx_core_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def wrapper(*inputs, **kwargs):
123123
for x in inputs:
124124
if isinstance(x, EagerTensor):
125125
tensor_class = x.__class__
126+
break
126127
if tensor_class is None:
127128
raise RuntimeError(
128129
f"Unable to find an EagerTensor in types "

onnx_array_api/npx/npx_functions.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
from typing import Optional, Tuple, Union
1+
from typing import Any, Optional, Tuple, Union
22

33
import numpy as np
4-
from onnx import FunctionProto, ModelProto, NodeProto
4+
from onnx import FunctionProto, ModelProto, NodeProto, TensorProto
5+
from onnx.helper import np_dtype_to_tensor_dtype
56
from onnx.numpy_helper import from_array
67

78
from .npx_constants import FUNCTION_DOMAIN
89
from .npx_core_api import cst, make_tuple, npxapi_inline, var
10+
from .npx_tensors import ArrayApi
911
from .npx_types import (
1012
ElemType,
1113
OptParType,
@@ -155,6 +157,50 @@ def arctanh(
155157
return var(x, op="Atanh")
156158

157159

160+
def asarray(
    a: Any,
    dtype: Any = None,
    order: Optional[str] = None,
    like: Any = None,
    copy: bool = False,
):
    """
    Converts anything into an array.

    Only instances of ``ArrayApi`` are supported. The tensor itself
    is returned unless *copy* is true, in which case a new instance
    of the same class is built from it. Parameter *like* is not used.

    :raises RuntimeError: if *dtype* is specified
    :raises NotImplementedError: if *order* is specified or
        if *a* is not an instance of ``ArrayApi``
    """
    if dtype is not None:
        raise RuntimeError("Method 'astype' should be used to change the type.")
    if order is not None:
        raise NotImplementedError(f"order={order!r} not implemented.")
    if not isinstance(a, ArrayApi):
        raise NotImplementedError(f"asarray not implemented for type {type(a)}.")
    return a.__class__(a, copy=copy) if copy else a
179+
180+
181+
@npxapi_inline
def astype(
    a: TensorType[ElemType.numerics, "T1"], dtype: OptParType[int] = 1
) -> TensorType[ElemType.numerics, "T2"]:
    """
    Cast an array.

    :param a: tensor to cast
    :param dtype: target type, anything ``np_dtype_to_tensor_dtype``
        accepts, or the python types ``int`` (mapped to INT64) and
        ``float`` (mapped to DOUBLE); it is an attribute of the node
        and cannot be a variable
    :return: the result of operator ``Cast``
    :raises TypeError: if *dtype* is a ``Var``
    :raises ValueError: if no tensor type can be guessed from *dtype*
    """
    # NOTE(review): the default value 1 presumably stands for
    # TensorProto.FLOAT, it goes through the same conversion as any
    # other value - confirm that it is handled as intended.
    if isinstance(dtype, Var):
        raise TypeError(
            f"dtype is an attribute, it cannot be a Variable of type {type(dtype)}."
        )
    try:
        to = np_dtype_to_tensor_dtype(dtype)
    except KeyError:
        if dtype is int:
            to = TensorProto.INT64
        elif dtype is float:
            # TensorProto has no member ``float64``,
            # 64-bit floats are named DOUBLE in ONNX.
            to = TensorProto.DOUBLE
        else:
            raise ValueError(f"Unable to guess tensor type from {dtype}.")
    return var(a, op="Cast", to=to)
202+
203+
158204
@npxapi_inline
159205
def cdist(
160206
xa: TensorType[ElemType.numerics, "T"],
@@ -412,6 +458,17 @@ def relu(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics,
412458
return var(x, op="Relu")
413459

414460

461+
@npxapi_inline
def reshape(
    x: TensorType[ElemType.numerics, "T"], shape: TensorType[ElemType.int64, "I"]
) -> TensorType[ElemType.numerics, "T"]:
    """
    See :func:`numpy.reshape`.

    An integer *shape* is first converted into a constant holding
    one int64 value. The shape then goes through a ``Reshape`` with
    target ``[-1]`` before being used to reshape *x*.
    """
    target = cst(np.array([shape], dtype=np.int64)) if isinstance(shape, int) else shape
    flat = cst(np.array([-1], dtype=np.int64))
    flat_target = var(target, flat, op="Reshape")
    return var(x, flat_target, op="Reshape")
470+
471+
415472
@npxapi_inline
416473
def round(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]:
417474
"See :func:`numpy.round`."
@@ -469,6 +526,16 @@ def squeeze(
469526
return var(x, axis, op="Squeeze")
470527

471528

529+
@npxapi_inline
def take(
    data: TensorType[ElemType.numerics, "T"],
    indices: TensorType[ElemType.int64, "I"],
    axis: ParType[int] = 0,
) -> TensorType[ElemType.numerics, "T"]:
    """
    See :func:`numpy.take`.

    Implemented with operator ``Gather``, *axis* is given
    to the operator as an attribute.
    """
    gathered = var(data, indices, op="Gather", axis=axis)
    return gathered
537+
538+
472539
@npxapi_inline
473540
def tan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]:
474541
"See :func:`numpy.tan`."

onnx_array_api/npx/npx_graph_builder.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,13 @@ def make_node(
235235
f"Cannot create a node Identity for {len(inputs)} input(s) and "
236236
f"{len(outputs)} output(s)."
237237
)
238-
node = make_node(op, inputs, outputs, domain=domain, **new_kwargs)
238+
try:
239+
node = make_node(op, inputs, outputs, domain=domain, **new_kwargs)
240+
except TypeError as e:
241+
raise TypeError(
242+
f"Unable to create node {op!r}, with inputs={inputs}, "
243+
f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}."
244+
) from e
239245
for p in protos:
240246
node.attribute.append(p)
241247
if attribute_protos is not None:

0 commit comments

Comments
 (0)