Skip to content

Commit 3f64e9c

Browse files
committed
fix eager mode with sklearn
1 parent b758cb7 commit 3f64e9c

File tree

6 files changed

+89
-13
lines changed

6 files changed

+89
-13
lines changed

_unittests/ut_npx/test_npx.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from onnx_array_api.npx.npx_functions import sinh as sinh_inline
7272
from onnx_array_api.npx.npx_functions import sqrt as sqrt_inline
7373
from onnx_array_api.npx.npx_functions import squeeze as squeeze_inline
74+
from onnx_array_api.npx.npx_functions import take as take_inline
7475
from onnx_array_api.npx.npx_functions import tan as tan_inline
7576
from onnx_array_api.npx.npx_functions import tanh as tanh_inline
7677
from onnx_array_api.npx.npx_functions import topk as topk_inline
@@ -2439,7 +2440,19 @@ def myloss(x, y):
24392440
res_eager = eager_myloss(x, y)
24402441
assert_allclose(res_jit, res_eager)
24412442

2443+
def test_take(self):
2444+
data = np.random.randn(3, 3).astype(np.float32)
2445+
indices = np.array([[0, 2]])
2446+
y = np.take(data, indices, axis=1)
2447+
2448+
f = take_inline(Input("A"), Input("B"), axis=1)
2449+
self.assertIsInstance(f, Var)
2450+
onx = f.to_onnx(constraints={"A": Float64[None], "B": Int64[None]})
2451+
ref = ReferenceEvaluator(onx)
2452+
got = ref.run(None, {"A": data, "B": indices})
2453+
self.assertEqualArray(y, got[0])
2454+
24422455

24432456
if __name__ == "__main__":
2444-
TestNpx().test_eager_cst_index()
2457+
TestNpx().test_take()
24452458
unittest.main(verbosity=2)

_unittests/ut_npx/test_sklearn_array_api.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,19 @@
1010
DEFAULT_OPSET = onnx_opset_version()
1111

1212

13+
def take(self, X, indices, *, axis):
14+
# Overwriting method take as it is using iterators.
15+
# When array_api supports `take` we can use this directly
16+
# https://github.com/data-apis/array-api/issues/177
17+
X_np = self._namespace.take(X, indices, axis=axis)
18+
return self._namespace.asarray(X_np)
19+
20+
1321
class TestSklearnArrayAPI(ExtTestCase):
1422
def test_sklearn_array_api_linear_discriminant(self):
23+
from sklearn.utils._array_api import _ArrayAPIWrapper
24+
25+
_ArrayAPIWrapper.take = take
1526
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
1627
y = np.array([1, 1, 1, 2, 2, 2])
1728
ana = LinearDiscriminantAnalysis()
@@ -23,11 +34,10 @@ def test_sklearn_array_api_linear_discriminant(self):
2334
self.assertStartsWith("EagerNumpyTensor(array([[", repr(new_x))
2435
with config_context(array_api_dispatch=True):
2536
got = ana.predict(new_x)
26-
self.assertEqualArray(expected, got)
37+
self.assertEqualArray(expected, got.numpy())
2738

2839

2940
if __name__ == "__main__":
30-
import logging
31-
32-
logging.basicConfig(level=logging.DEBUG)
41+
# import logging
42+
# logging.basicConfig(level=logging.DEBUG)
3343
unittest.main(verbosity=2)

onnx_array_api/npx/npx_functions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,16 @@ def squeeze(
526526
return var(x, axis, op="Squeeze")
527527

528528

529+
@npxapi_inline
530+
def take(
531+
data: TensorType[ElemType.numerics, "T"],
532+
indices: TensorType[ElemType.int64, "I"],
533+
axis: ParType[int] = 0,
534+
) -> TensorType[ElemType.numerics, "T"]:
535+
"See :func:`numpy.take`."
536+
return var(data, indices, op="Gather", axis=axis)
537+
538+
529539
@npxapi_inline
530540
def tan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]:
531541
"See :func:`numpy.tan`."

onnx_array_api/npx/npx_jit_eager.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,15 @@ def make_key(*values, **kwargs):
145145
elif isinstance(sk, (int, float)):
146146
res.append(("slice", sk))
147147
else:
148-
raise TypeError(f"Input {iv} cannot be such tuple: {v}.")
148+
raise TypeError(f"Input {iv} cannot have such tuple: {v}.")
149149
res.append(tuple(subkey))
150150
else:
151151
raise TypeError(
152152
f"Unable to build a key, input {iv} has type {type(v)}."
153153
)
154154
if kwargs:
155155
for k, v in sorted(kwargs.items()):
156-
if isinstance(v, (int, float, str)):
156+
if isinstance(v, (int, float, str, type)):
157157
res.append(k)
158158
res.append(v)
159159
elif isinstance(v, tuple):
@@ -307,8 +307,8 @@ def move_input_to_kwargs(
307307
:return: new values, new arguments
308308
"""
309309
if self.input_to_kwargs_ is None:
310-
if self.bypass_eager or self.f.__annotations__:
311-
return values, kwargs
310+
# if self.bypass_eager or self.f.__annotations__:
311+
# return values, kwargs
312312
raise RuntimeError(
313313
f"self.input_to_kwargs_ is not initialized for function {self.f} "
314314
f"from module {self.f.__module__!r}."
@@ -321,8 +321,8 @@ def move_input_to_kwargs(
321321
if i in self.input_to_kwargs_:
322322
new_kwargs[self.input_to_kwargs_[i]] = v
323323
else:
324-
new_values.append(values)
325-
return new_values, new_kwargs
324+
new_values.append(v)
325+
return tuple(new_values), new_kwargs
326326

327327
def jit_call(self, *values, **kwargs):
328328
"""
@@ -334,15 +334,33 @@ def jit_call(self, *values, **kwargs):
334334
and returns the result or the results in a tuple if there are several.
335335
"""
336336
self.info("+", "jit_call")
337+
if self.input_to_kwargs_ is None:
338+
# No jitting was ever called.
339+
onx, fct = self.to_jit(*values, **kwargs)
340+
if self.input_to_kwargs_ is None:
341+
raise RuntimeError(
342+
f"Attribute 'input_to_kwargs_' should be set for "
343+
f"function {self.f} form module {self.f.__module__!r}."
344+
)
345+
else:
346+
onx, fct = None, None
347+
337348
values, kwargs = self.move_input_to_kwargs(values, kwargs)
338349
key = self.make_key(*values, **kwargs)
339350
if self.method_name_ is None and "method_name" in key:
340351
pos = list(key).index("method_name")
341352
self.method_name_ = key[pos + 1]
342353

343-
if key in self.versions:
354+
if onx is not None:
355+
# First jitting.
356+
self.versions[key] = fct
357+
self.onxs[key] = onx
358+
elif key in self.versions:
359+
# Already jitted.
344360
fct = self.versions[key]
345361
else:
362+
# One version was already jitted but types or parameters
363+
# are different.
346364
onx, fct = self.to_jit(*values, **kwargs)
347365
self.versions[key] = fct
348366
self.onxs[key] = onx
@@ -482,6 +500,9 @@ def _preprocess_constants(self, *args):
482500
# f"function {self.f} from module {self.f.__module__!r}."
483501
# )
484502
# new_args.append(n)
503+
elif isinstance(n, np.ndarray):
504+
new_args.append(self.tensor_class(n))
505+
modified = True
485506
elif isinstance(n, (int, float)):
486507
new_args.append(self.tensor_class(np.array(n)))
487508
modified = True
@@ -518,7 +539,8 @@ def __call__(self, *args, already_eager=False, **kwargs):
518539
map(
519540
lambda t: t is not None
520541
and not isinstance(
521-
t, (EagerTensor, Cst, int, float, tuple, slice, type)
542+
t,
543+
(EagerTensor, Cst, int, float, tuple, slice, type, np.ndarray),
522544
),
523545
args,
524546
)

onnx_array_api/npx/npx_tensors.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,17 @@ def const_cast(self, to: Any = None) -> "EagerTensor":
3030
f"{self.__class__.__name__!r}."
3131
)
3232

33+
def __iter__(self):
34+
"""
35+
This is not implemented in the generic case.
36+
This method raises an exception with a better error message.
37+
"""
38+
raise RuntimeError(
39+
"Iterators are not implemented in the generic case. "
40+
"It may be enabled for the eager mode but it might fail "
41+
"when a whole function is converted into ONNX."
42+
)
43+
3344
@staticmethod
3445
def _op_impl(*inputs, method_name=None):
3546
# avoids circular imports.

onnx_array_api/npx/npx_var.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,16 @@ def to_onnx(
561561

562562
# Operators
563563

564+
def __iter__(self):
565+
"""
566+
This is not implemented in the generic case.
567+
This method raises an exception with a better error message.
568+
"""
569+
raise RuntimeError(
570+
"Iterators are not implemented in the generic case. "
571+
"Every function using them cannot be converted into ONNX."
572+
)
573+
564574
def _binary_op(self, ov: "Var", op_name: str, **kwargs) -> "Var":
565575
var = Var.get_cst_var()[1]
566576
if isinstance(ov, (int, float, np.ndarray, Cst)):

0 commit comments

Comments
 (0)